Yingxu He committed · Commit 2cec09f · verified · 1 Parent(s): e5b7870

Upload modeling_meralion.py with huggingface_hub

Files changed (1)
  1. modeling_meralion.py +1368 -0
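
For context: a modeling file uploaded this way is downloaded and executed at load time when the checkpoint is loaded with `trust_remote_code=True`. The snippet below is a minimal, illustrative sketch only; the repository id and the choice of auto class are assumptions, not taken from this commit.

import torch
from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq

# Hypothetical repository id; replace with the repo that actually ships modeling_meralion.py.
repo_id = "MERaLiON/MERaLiON-AudioLLM-Whisper-SEA-LION"

processor = AutoProcessor.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModelForSpeechSeq2Seq.from_pretrained(
    repo_id,
    trust_remote_code=True,   # required so the custom modeling code in this file is used
    torch_dtype=torch.bfloat16,
)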
modeling_meralion.py ADDED
@@ -0,0 +1,1368 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """PyTorch MERaLiON model."""
16
+
17
+ import math
18
+ from dataclasses import dataclass
19
+ from typing import List, Optional, Tuple, Union
20
+
21
+ import torch
22
+ import torch.utils.checkpoint
23
+ from torch import nn
24
+
25
+ from transformers.activations import ACT2FN
26
+ from transformers.cache_utils import EncoderDecoderCache, StaticCache, HybridCache
27
+ from transformers.generation import GenerationMixin
28
+ from transformers.modeling_outputs import ModelOutput, BaseModelOutput
29
+ from transformers.modeling_utils import PreTrainedModel
30
+ from transformers.utils import (
31
+ add_start_docstrings,
32
+ add_start_docstrings_to_model_forward,
33
+ is_flash_attn_2_available,
34
+ is_flash_attn_greater_or_equal_2_10,
35
+ logging,
36
+ replace_return_docstrings,
37
+ )
38
+
39
+ from .configuration_meralion import MERaLiONConfig, MERaLiONSpeechConfig
40
+ from .modeling_text_decoder import MERaLiONTextForCausalLM
41
+
42
+
43
+ if is_flash_attn_2_available():
44
+ from transformers.modeling_flash_attention_utils import _flash_attention_forward
45
+
46
+
47
+ logger = logging.get_logger(__name__)
48
+
49
+ _CONFIG_FOR_DOC = "MERaLiONConfig"
50
+
51
+
52
+ def sinusoids(length: int, channels: int, max_timescale: float = 10000) -> torch.Tensor:
53
+ """Returns sinusoids for positional embedding"""
54
+ if channels % 2 != 0:
55
+ raise ValueError(
56
+ f"Number of channels has to be divisible by 2 for sinusoidal positional embeddings, got {channels} channels."
57
+ )
58
+ log_timescale_increment = math.log(max_timescale) / (channels // 2 - 1)
59
+ inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
60
+ scaled_time = torch.arange(length).view(-1, 1) * inv_timescales.view(1, -1)
61
+ return torch.cat([scaled_time.sin(), scaled_time.cos()], dim=1)
62
+
63
+
64
+ # Copied from transformers.models.bart.modeling_bart.shift_tokens_right
65
+ def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
66
+ """
67
+ Shift input ids one token to the right.
68
+ """
69
+ shifted_input_ids = input_ids.new_zeros(input_ids.shape)
70
+ shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
71
+ shifted_input_ids[:, 0] = decoder_start_token_id
72
+
73
+ if pad_token_id is None:
74
+ raise ValueError("self.model.config.pad_token_id has to be defined.")
75
+ # replace possible -100 values in labels by `pad_token_id`
76
+ shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
77
+
78
+ return shifted_input_ids
79
+
80
+
81
+ # Copied from transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_with_cache_position
82
+ def _prepare_4d_causal_attention_mask_with_cache_position(
83
+ attention_mask: torch.Tensor,
84
+ sequence_length: int,
85
+ target_length: int,
86
+ dtype: torch.dtype,
87
+ device: torch.device,
88
+ min_dtype: float,
89
+ cache_position: torch.Tensor,
90
+ batch_size: int,
91
+ ):
92
+ """
93
+ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
94
+ `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
95
+
96
+ Args:
97
+ attention_mask (`torch.Tensor`):
98
+ A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
99
+ sequence_length (`int`):
100
+ The sequence length being processed.
101
+ target_length (`int`):
102
+ The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, i.e. the part of the cache that is not filled yet.
103
+ dtype (`torch.dtype`):
104
+ The dtype to use for the 4D attention mask.
105
+ device (`torch.device`):
106
+ The device to place the 4D attention mask on.
107
+ min_dtype (`float`):
108
+ The minimum value representable with the dtype `dtype`.
109
+ cache_position (`torch.Tensor`):
110
+ Indices depicting the position of the input sequence tokens in the sequence.
111
+ batch_size (`torch.Tensor`):
112
+ Batch size.
113
+ """
114
+ if attention_mask is not None and attention_mask.dim() == 4:
115
+ # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
116
+ causal_mask = attention_mask
117
+ else:
118
+ causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
119
+ if sequence_length != 1:
120
+ causal_mask = torch.triu(causal_mask, diagonal=1)
121
+ causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
122
+ causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
123
+ if attention_mask is not None:
124
+ causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
125
+ mask_length = attention_mask.shape[-1]
126
+ padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
127
+ padding_mask = padding_mask == 0
128
+ causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
129
+ padding_mask, min_dtype
130
+ )
131
+ return causal_mask
132
+
133
+
134
+ class MERaLiONSpeechAttention(nn.Module):
135
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
136
+
137
+ def __init__(
138
+ self,
139
+ embed_dim: int,
140
+ num_heads: int,
141
+ dropout: float = 0.0,
142
+ is_decoder: bool = False,
143
+ bias: bool = True,
144
+ is_causal: bool = False,
145
+ layer_idx: Optional[int] = None,
146
+ config: Optional[MERaLiONSpeechConfig] = None,
147
+ ):
148
+ super().__init__()
149
+ self.embed_dim = embed_dim
150
+ self.num_heads = num_heads
151
+ self.dropout = dropout
152
+ self.head_dim = embed_dim // num_heads
153
+ self.config = config
154
+
155
+ if (self.head_dim * num_heads) != self.embed_dim:
156
+ raise ValueError(
157
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
158
+ f" and `num_heads`: {num_heads})."
159
+ )
160
+ self.scaling = self.head_dim**-0.5
161
+ self.is_decoder = is_decoder
162
+ self.is_causal = is_causal
163
+
164
+ if layer_idx is None and is_decoder:
165
+ logger.warning_once(
166
+ f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and "
167
+ "will to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
168
+ "when creating this class."
169
+ )
170
+ self.layer_idx = layer_idx
171
+
172
+ self.k_proj = nn.Linear(embed_dim, embed_dim, bias=False)
173
+ self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
174
+ self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
175
+ self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
176
+
177
+ # Copied from transformers.models.bart.modeling_bart.BartAttention._shape with BART->speech
178
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
179
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
180
+
181
+ def forward(
182
+ self,
183
+ hidden_states: torch.Tensor,
184
+ key_value_states: Optional[torch.Tensor] = None,
185
+ past_key_value: Optional[EncoderDecoderCache] = None,
186
+ attention_mask: Optional[torch.Tensor] = None,
187
+ layer_head_mask: Optional[torch.Tensor] = None,
188
+ output_attentions: bool = False,
189
+ cache_position: Optional[torch.LongTensor] = None,
190
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
191
+ """Input shape: Batch x Time x Channel"""
192
+
193
+ # if key_value_states are provided this layer is used as a cross-attention layer
194
+ # for the decoder
195
+ is_cross_attention = key_value_states is not None
196
+ bsz, tgt_len, _ = hidden_states.size()
197
+
198
+ # get query proj
199
+ query_states = self._shape(self.q_proj(hidden_states) * self.scaling, tgt_len, bsz)
200
+
201
+ if past_key_value is not None:
202
+ is_updated = past_key_value.is_updated.get(self.layer_idx)
203
+ if is_cross_attention:
204
+ # after the first generated id, we can subsequently re-use all key/value_states from cache
205
+ past_key_value.is_updated[self.layer_idx] = True
206
+ past_key_value = past_key_value.cross_attention_cache
207
+ else:
208
+ past_key_value = past_key_value.self_attention_cache
209
+
210
+ # use key_value_states if cross attention
211
+ current_states = key_value_states if key_value_states is not None else hidden_states
212
+ if is_cross_attention and past_key_value and is_updated:
213
+ # reuse k,v, cross_attentions
214
+ key_states = past_key_value.key_cache[self.layer_idx]
215
+ value_states = past_key_value.value_cache[self.layer_idx]
216
+ else:
217
+ key_states = self._shape(self.k_proj(current_states), -1, bsz)
218
+ value_states = self._shape(self.v_proj(current_states), -1, bsz)
219
+ if past_key_value is not None:
220
+ # save all key/value_states to cache to be re-used for fast auto-regressive generation
221
+ cache_position = cache_position if not is_cross_attention else None
222
+ key_states, value_states = past_key_value.update(
223
+ key_states, value_states, self.layer_idx, {"cache_position": cache_position}
224
+ )
225
+
226
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3))
227
+
228
+ if attention_mask is not None: # no matter the length, we just slice it
229
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
230
+ attn_weights = attn_weights + causal_mask
231
+
232
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
233
+
234
+ if layer_head_mask is not None:
235
+ if layer_head_mask.size() != (self.num_heads,):
236
+ raise ValueError(
237
+ f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
238
+ f" {layer_head_mask.size()}"
239
+ )
240
+ attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights
241
+
242
+ attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
243
+ attn_output = torch.matmul(attn_probs, value_states)
244
+
245
+ if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim):
246
+ raise ValueError(
247
+ f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
248
+ f" {attn_output.size()}"
249
+ )
250
+
251
+ attn_output = attn_output.transpose(1, 2)
252
+ # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
253
+ # partitioned across GPUs when using tensor-parallelism.
254
+ attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
255
+
256
+ attn_output = self.out_proj(attn_output)
257
+
258
+ return attn_output, attn_weights, past_key_value
259
+
260
+
261
+ class MERaLiONSpeechFlashAttention2(MERaLiONSpeechAttention):
262
+ """
263
+ MERaLiONSpeech flash attention module. This module inherits from `MERaLiONSpeechAttention` as the weights of the module stay
264
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
265
+ flash attention and deal with padding tokens in case the input contains any of them.
266
+ """
267
+
268
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
269
+ def __init__(self, *args, **kwargs):
270
+ super().__init__(*args, **kwargs)
271
+
272
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
273
+ # flash_attn<2.1 generates a top-left aligned causal mask, while what is needed here is bottom-right alignment, which was made the default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
274
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
275
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
276
+
277
+ def forward(
278
+ self,
279
+ hidden_states: torch.Tensor,
280
+ key_value_states: Optional[torch.Tensor] = None,
281
+ past_key_value: Optional[EncoderDecoderCache] = None,
282
+ attention_mask: Optional[torch.Tensor] = None,
283
+ layer_head_mask: Optional[torch.Tensor] = None,
284
+ output_attentions: bool = False,
285
+ cache_position: Optional[torch.LongTensor] = None,
286
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
287
+ if isinstance(past_key_value, StaticCache):
288
+ raise ValueError(
289
+ "The `static` cache implementation is not compatible with `attn_implementation='flash_attention_2'`. "
290
+ "Use `attn_implementation='sdpa'` in the meantime, and open an issue at https://github.com/huggingface/transformers"
291
+ )
292
+ # SpeechFlashAttention2 attention does not support output_attentions
293
+ if output_attentions:
294
+ raise ValueError("SpeechFlashAttention2 attention does not support output_attentions")
295
+
296
+ # if key_value_states are provided this layer is used as a cross-attention layer
297
+ # for the decoder
298
+ is_cross_attention = key_value_states is not None
299
+ bsz, tgt_len, _ = hidden_states.size()
300
+
301
+ # get query proj
302
+ query_states = torch.reshape(self.q_proj(hidden_states), (bsz, tgt_len, self.num_heads, self.head_dim))
303
+
304
+ if past_key_value is not None:
305
+ is_updated = past_key_value.is_updated.get(self.layer_idx)
306
+ if is_cross_attention:
307
+ # after the first generated id, we can subsequently re-use all key/value_states from cache
308
+ past_key_value.is_updated[self.layer_idx] = True
309
+ past_key_value = past_key_value.cross_attention_cache
310
+ else:
311
+ past_key_value = past_key_value.self_attention_cache
312
+
313
+ # use key_value_states if cross attention
314
+ current_states = key_value_states if key_value_states is not None else hidden_states
315
+ if is_cross_attention and past_key_value and is_updated:
316
+ # reuse k,v, cross_attentions
317
+ key_states = past_key_value.key_cache[self.layer_idx]
318
+ value_states = past_key_value.value_cache[self.layer_idx]
319
+ else:
320
+ key_states = self._shape(self.k_proj(current_states), -1, bsz)
321
+ value_states = self._shape(self.v_proj(current_states), -1, bsz)
322
+ if past_key_value is not None:
323
+ # save all key/value_states to cache to be re-used for fast auto-regressive generation
324
+ cache_position = cache_position if not is_cross_attention else None
325
+ key_states, value_states = past_key_value.update(
326
+ key_states, value_states, self.layer_idx, {"cache_position": cache_position}
327
+ )
328
+
329
+ # TODO: These transposes are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]
330
+ # We would need to refactor the KV cache to be able to avoid many of these transpose/reshape/view.
331
+ key_states = key_states.transpose(1, 2)
332
+ value_states = value_states.transpose(1, 2)
333
+
334
+ causal_mask = attention_mask
335
+ if attention_mask is not None: # no matter the length, we just slice it
336
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
337
+
338
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
339
+ # therefore the input hidden states get silently cast to float32. Hence, we need to
340
+ # cast them back to the correct dtype just to be sure everything works as expected.
341
+ # This might slow down training & inference so it is recommended not to cast the LayerNorms
342
+ # in fp32. (LlamaRMSNorm handles it correctly)
343
+
344
+ input_dtype = query_states.dtype
345
+ if input_dtype == torch.float32:
346
+ if torch.is_autocast_enabled():
347
+ target_dtype = torch.get_autocast_gpu_dtype()
348
+ # Handle the case where the model is quantized
349
+ elif hasattr(self.config, "_pre_quantization_dtype"):
350
+ target_dtype = self.config._pre_quantization_dtype
351
+ else:
352
+ target_dtype = self.q_proj.weight.dtype
353
+
354
+ logger.warning_once(
355
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
356
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
357
+ f" {target_dtype}."
358
+ )
359
+
360
+ query_states = query_states.to(target_dtype)
361
+ key_states = key_states.to(target_dtype)
362
+ value_states = value_states.to(target_dtype)
363
+
364
+ attn_output = _flash_attention_forward(
365
+ query_states,
366
+ key_states,
367
+ value_states,
368
+ causal_mask,
369
+ tgt_len,
370
+ dropout=self.dropout if self.training else 0.0,
371
+ is_causal=self.is_causal,
372
+ use_top_left_mask=self._flash_attn_uses_top_left_mask,
373
+ )
374
+
375
+ attn_output = attn_output.reshape(bsz, tgt_len, -1)
376
+ attn_output = self.out_proj(attn_output)
377
+
378
+ if not output_attentions:
379
+ attn_weights = None
380
+
381
+ return attn_output, attn_weights, past_key_value
382
+
383
+
384
+ class MERaLiONSpeechSdpaAttention(MERaLiONSpeechAttention):
385
+ def forward(
386
+ self,
387
+ hidden_states: torch.Tensor,
388
+ key_value_states: Optional[torch.Tensor] = None,
389
+ past_key_value: Optional[EncoderDecoderCache] = None,
390
+ attention_mask: Optional[torch.Tensor] = None,
391
+ layer_head_mask: Optional[torch.Tensor] = None,
392
+ output_attentions: bool = False,
393
+ cache_position: Optional[torch.LongTensor] = None,
394
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
395
+ """Input shape: Batch x Time x Channel"""
396
+ if output_attentions or layer_head_mask is not None:
397
+ # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented.
398
+ logger.warning_once(
399
+ "MERaLiONSpeechModel is using MERaLiONSpeechSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` or `layer_head_mask` not None. Falling back to the manual attention"
400
+ ' implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
401
+ )
402
+ return super().forward(
403
+ hidden_states,
404
+ key_value_states=key_value_states,
405
+ past_key_value=past_key_value,
406
+ attention_mask=attention_mask,
407
+ layer_head_mask=layer_head_mask,
408
+ output_attentions=output_attentions,
409
+ cache_position=cache_position,
410
+ )
411
+
412
+ # if key_value_states are provided this layer is used as a cross-attention layer
413
+ # for the decoder
414
+ is_cross_attention = key_value_states is not None
415
+ bsz, tgt_len, _ = hidden_states.size()
416
+
417
+ # get query proj
418
+ query_states = self._shape(self.q_proj(hidden_states), tgt_len, bsz)
419
+
420
+ if past_key_value is not None:
421
+ is_updated = past_key_value.is_updated.get(self.layer_idx)
422
+ if is_cross_attention:
423
+ # after the first generated id, we can subsequently re-use all key/value_states from cache
424
+ past_key_value.is_updated[self.layer_idx] = True
425
+ past_key_value = past_key_value.cross_attention_cache
426
+ else:
427
+ past_key_value = past_key_value.self_attention_cache
428
+
429
+ # use key_value_states if cross attention
430
+ current_states = key_value_states if key_value_states is not None else hidden_states
431
+ if is_cross_attention and past_key_value and is_updated:
432
+ # reuse k,v, cross_attentions
433
+ key_states = past_key_value.key_cache[self.layer_idx]
434
+ value_states = past_key_value.value_cache[self.layer_idx]
435
+ else:
436
+ key_states = self._shape(self.k_proj(current_states), -1, bsz)
437
+ value_states = self._shape(self.v_proj(current_states), -1, bsz)
438
+ if past_key_value is not None:
439
+ # save all key/value_states to cache to be re-used for fast auto-regressive generation
440
+ cache_position = cache_position if not is_cross_attention else None
441
+ key_states, value_states = past_key_value.update(
442
+ key_states, value_states, self.layer_idx, {"cache_position": cache_position}
443
+ )
444
+
445
+ causal_mask = attention_mask
446
+ if attention_mask is not None: # no matter the length, we just slice it
447
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
448
+
449
+ # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
450
+ # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
451
+ # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1.
452
+ is_causal = True if self.is_causal and causal_mask is None and tgt_len > 1 else False
453
+
454
+ # NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask,
455
+ # but we are fine here as `_shape` do call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577
456
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
457
+ query_states,
458
+ key_states,
459
+ value_states,
460
+ attn_mask=causal_mask,
461
+ dropout_p=self.dropout if self.training else 0.0,
462
+ is_causal=is_causal,
463
+ )
464
+
465
+ if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim):
466
+ raise ValueError(
467
+ f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
468
+ f" {attn_output.size()}"
469
+ )
470
+
471
+ attn_output = attn_output.transpose(1, 2)
472
+
473
+ # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
474
+ # partitioned across GPUs when using tensor-parallelism.
475
+ attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
476
+
477
+ attn_output = self.out_proj(attn_output)
478
+
479
+ return attn_output, None, past_key_value
480
+
481
+
482
+ MERALION_SPEECH_ATTENTION_CLASSES = {
483
+ "eager": MERaLiONSpeechAttention,
484
+ "flash_attention_2": MERaLiONSpeechFlashAttention2,
485
+ "sdpa": MERaLiONSpeechSdpaAttention,
486
+ }
487
+
488
+
489
+ # Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->Speech, MBART->WHISPER
490
+ class MERaLiONSpeechEncoderLayer(nn.Module):
491
+ def __init__(self, config: MERaLiONSpeechConfig):
492
+ super().__init__()
493
+ self.embed_dim = config.d_model
494
+
495
+ self.self_attn = MERALION_SPEECH_ATTENTION_CLASSES[config._attn_implementation](
496
+ embed_dim=self.embed_dim,
497
+ num_heads=config.encoder_attention_heads,
498
+ dropout=config.attention_dropout,
499
+ config=config,
500
+ )
501
+ self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
502
+ self.dropout = config.dropout
503
+ self.activation_fn = ACT2FN[config.activation_function]
504
+ self.activation_dropout = config.activation_dropout
505
+ self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
506
+ self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
507
+ self.final_layer_norm = nn.LayerNorm(self.embed_dim)
508
+
509
+ def forward(
510
+ self,
511
+ hidden_states: torch.Tensor,
512
+ attention_mask: torch.Tensor,
513
+ layer_head_mask: torch.Tensor,
514
+ output_attentions: bool = False,
515
+ ) -> torch.Tensor:
516
+ """
517
+ Args:
518
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
519
+ attention_mask (`torch.FloatTensor`): attention mask of size
520
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
521
+ layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
522
+ `(encoder_attention_heads,)`.
523
+ output_attentions (`bool`, *optional*):
524
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
525
+ returned tensors for more detail.
526
+ """
527
+ residual = hidden_states
528
+ hidden_states = self.self_attn_layer_norm(hidden_states)
529
+ hidden_states, attn_weights, _ = self.self_attn(
530
+ hidden_states=hidden_states,
531
+ attention_mask=attention_mask,
532
+ layer_head_mask=layer_head_mask,
533
+ output_attentions=output_attentions,
534
+ )
535
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
536
+ hidden_states = residual + hidden_states
537
+
538
+ residual = hidden_states
539
+ hidden_states = self.final_layer_norm(hidden_states)
540
+ hidden_states = self.activation_fn(self.fc1(hidden_states))
541
+ hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
542
+ hidden_states = self.fc2(hidden_states)
543
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
544
+ hidden_states = residual + hidden_states
545
+
546
+ if hidden_states.dtype == torch.float16 and (
547
+ torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()
548
+ ):
549
+ clamp_value = torch.finfo(hidden_states.dtype).max - 1000
550
+ hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
551
+
552
+ outputs = (hidden_states,)
553
+
554
+ if output_attentions:
555
+ outputs += (attn_weights,)
556
+
557
+ return outputs
558
+
559
+
560
+ class MERaLiONSpeechPreTrainedModel(PreTrainedModel):
561
+ config_class = MERaLiONSpeechConfig
562
+ base_model_prefix = "model"
563
+ main_input_name = "input_features"
564
+ supports_gradient_checkpointing = True
565
+ _no_split_modules = ["MERaLiONSpeechEncoderLayer", "MERaLiONSpeechDecoderLayer"]
566
+ _supports_flash_attn_2 = True
567
+ _supports_sdpa = True
568
+ _supports_cache_class = True
569
+ _supports_static_cache = True
570
+
571
+ def _init_weights(self, module):
572
+ std = self.config.init_std
573
+ if isinstance(module, (nn.Linear, nn.Conv1d)):
574
+ module.weight.data.normal_(mean=0.0, std=std)
575
+ if module.bias is not None:
576
+ module.bias.data.zero_()
577
+ elif isinstance(module, nn.Embedding):
578
+ module.weight.data.normal_(mean=0.0, std=std)
579
+ if module.padding_idx is not None:
580
+ module.weight.data[module.padding_idx].zero_()
581
+ elif isinstance(module, MERaLiONSpeechEncoder):
582
+ with torch.no_grad():
583
+ embed_positions = module.embed_positions.weight
584
+ embed_positions.copy_(sinusoids(*embed_positions.shape))
585
+
586
+ def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor):
587
+ """
588
+ Computes the output length of the convolutional layers
589
+ """
590
+ input_lengths = (input_lengths - 1) // 2 + 1
591
+
592
+ return input_lengths
593
+
594
+
595
+ MERALION_SPEECH_START_DOCSTRING = r"""
596
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
597
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
598
+ etc.)
599
+
600
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
601
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
602
+ and behavior.
603
+
604
+ Parameters:
605
+ config ([`MERaLiONSpeechConfig`]):
606
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
607
+ load the weights associated with the model, only the configuration. Check out the
608
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
609
+ """
610
+
611
+ MERALION_SPEECH_INPUTS_DOCSTRING = r"""
612
+ Args:
613
+ input_features (`torch.FloatTensor` of shape `(batch_size, feature_size, sequence_length)`):
614
+ Float values of mel features extracted from the raw speech waveform. Raw speech waveform can be obtained by
615
+ loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a `numpy.ndarray`, *e.g.* via
616
+ the soundfile library (`pip install soundfile`). To prepare the array into `input_features`, the
617
+ [`AutoFeatureExtractor`] should be used for extracting the mel features, padding and conversion into a
618
+ tensor of type `torch.FloatTensor`. See [`~SpeechFeatureExtractor.__call__`]
619
+ attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
620
+ Mask to avoid performing *SpecAugment* data augmentation on padding token indices. Mask values selected in
621
+ `[0, 1]`:
622
+
623
+ - 1 for tokens that are **not masked**,
624
+ - 0 for tokens that are **masked**.
625
+
626
+ [What are attention masks?](../glossary#attention-mask)
627
+ decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
628
+ Indices of decoder input sequence tokens in the vocabulary.
629
+
630
+ Indices can be obtained using [`SpeechTokenizer`]. See [`PreTrainedTokenizer.encode`] and
631
+ [`PreTrainedTokenizer.__call__`] for details.
632
+
633
+ [What are decoder input IDs?](../glossary#decoder-input-ids)
634
+
635
+ Speech uses the `decoder_start_token_id` as the starting token for `decoder_input_ids` generation. If
636
+ `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
637
+ `past_key_values`).
638
+ decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
639
+ Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
640
+ be used by default.
641
+
642
+ If you want to change padding behavior, you should read
643
+ [`modeling_speech._prepare_decoder_attention_mask`] and modify to your needs. See diagram 1 in [the BART
644
+ paper](https://arxiv.org/abs/1910.13461) for more information on the default strategy.
645
+ head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
646
+ Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`:
647
+
648
+ - 1 indicates the head is **not masked**,
649
+ - 0 indicates the head is **masked**.
650
+
651
+ decoder_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
652
+ Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in `[0, 1]`:
653
+
654
+ - 1 indicates the head is **not masked**,
655
+ - 0 indicates the head is **masked**.
656
+
657
+ cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
658
+ Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`:
659
+
660
+ - 1 indicates the head is **not masked**,
661
+ - 0 indicates the head is **masked**.
662
+
663
+ encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*):
664
+ Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`)
665
+ `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of
666
+ hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
667
+ past_key_values (`EncoderDecoderCache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
668
+ Pre-computed hidden-states that can be used to speed up auto-regressive (sequential) decoding. There are
669
+ four sets of pre-computed hidden-states: key and values states in the self-attention blocks (2) and
670
+ in the cross-attention blocks (2). The `past_key_values` are returned when `use_cache=True` is passed or
671
+ when `config.use_cache=True`
672
+
673
+ Two formats are allowed:
674
+ - An [`~cache_utils.EncoderDecoderCache`] instance;
675
+ - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
676
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
677
+ `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
678
+
679
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
680
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
681
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
682
+ decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*):
683
+ Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded
684
+ representation. If `past_key_values` is used, optionally only the last `decoder_inputs_embeds` have to be
685
+ input (see `past_key_values`). This is useful if you want more control over how to convert
686
+ `decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix.
687
+ use_cache (`bool`, *optional*):
688
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
689
+ `past_key_values`).
690
+ output_attentions (`bool`, *optional*):
691
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
692
+ tensors for more detail.
693
+ output_hidden_states (`bool`, *optional*):
694
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
695
+ more detail.
696
+ return_dict (`bool`, *optional*):
697
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
698
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
699
+ Indices depicting the position of the input sequence tokens in the sequence. It is used to update the cache
700
+ in the correct position and to infer the complete sequence length.
701
+ """
702
+
703
+ MERALION_SPEECH_ENCODER_INPUTS_DOCSTRING = r"""
704
+ Args:
705
+ input_features (`torch.FloatTensor` of shape `(batch_size, feature_size, sequence_length)`):
706
+ Float values of mel features extracted from the raw speech waveform. Raw speech waveform can be obtained by
707
+ loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a `numpy.ndarray`, *e.g.* via
708
+ the soundfile library (`pip install soundfile`). To prepare the array into `input_features`, the
709
+ [`AutoFeatureExtractor`] should be used for extracting the mel features, padding and conversion into a
710
+ tensor of type `torch.FloatTensor`. See [`~SpeechFeatureExtractor.__call__`]
711
+ head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
712
+ Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`:
713
+
714
+ - 1 indicates the head is **not masked**,
715
+ - 0 indicates the head is **masked**.
716
+ encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*):
717
+ Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`)
718
+ `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of
719
+ hidden-states at the output of the last layer of the encoder.
720
+ output_attentions (`bool`, *optional*):
721
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
722
+ tensors for more detail.
723
+ output_hidden_states (`bool`, *optional*):
724
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
725
+ more detail.
726
+ return_dict (`bool`, *optional*):
727
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
728
+ """
729
+
730
+
731
+ class MERaLiONSpeechEncoder(MERaLiONSpeechPreTrainedModel):
732
+ """
733
+ Transformer encoder consisting of *config.encoder_layers* self-attention layers. Each layer is a
734
+ [`MERaLiONSpeechEncoderLayer`].
735
+
736
+ Args:
737
+ config: MERaLiONSpeechConfig
738
+ """
739
+
740
+ def __init__(self, config: MERaLiONSpeechConfig):
741
+ super().__init__(config)
742
+ self.dropout = config.dropout
743
+ self.layerdrop = config.encoder_layerdrop
744
+
745
+ embed_dim = config.d_model
746
+ self.num_mel_bins = config.num_mel_bins
747
+ self.padding_idx = config.pad_token_id
748
+ self.max_source_positions = config.max_source_positions
749
+ self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
750
+
751
+ self.conv1 = nn.Conv1d(self.num_mel_bins, embed_dim, kernel_size=3, padding=1)
752
+ self.conv2 = nn.Conv1d(embed_dim, embed_dim, kernel_size=3, stride=2, padding=1)
753
+
754
+ self.embed_positions = nn.Embedding(self.max_source_positions, embed_dim)
755
+ self.embed_positions.requires_grad_(False)
756
+
757
+ self.layers = nn.ModuleList([MERaLiONSpeechEncoderLayer(config) for _ in range(config.encoder_layers)])
758
+ self.layer_norm = nn.LayerNorm(config.d_model)
759
+
760
+ self.gradient_checkpointing = False
761
+ # Initialize weights and apply final processing
762
+ self.post_init()
763
+
764
+ def _freeze_parameters(self):
765
+ for param in self.parameters():
766
+ param.requires_grad = False
767
+ self._requires_grad = False
768
+
769
+ def get_input_embeddings(self) -> nn.Module:
770
+ return self.conv1
771
+
772
+ def set_input_embeddings(self, value: nn.Module):
773
+ self.conv1 = value
774
+
775
+ def forward(
776
+ self,
777
+ input_features,
778
+ attention_mask=None,
779
+ head_mask=None,
780
+ output_attentions=None,
781
+ output_hidden_states=None,
782
+ return_dict=None,
783
+ ):
784
+ r"""
785
+ Args:
786
+ input_features (`torch.LongTensor` of shape `(batch_size, feature_size, sequence_length)`):
787
+ Float values of mel features extracted from the raw speech waveform. Raw speech waveform can be
788
+ obtained by loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a
789
+ `numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array into
790
+ `input_features`, the [`AutoFeatureExtractor`] should be used for extracting the mel features, padding
791
+ and conversion into a tensor of type `torch.FloatTensor`. See [`~SpeechFeatureExtractor.__call__`]
792
+ attention_mask (`torch.Tensor`, *optional*):
793
+ Speech does not support masking of the `input_features`; this argument is preserved for compatibility,
794
+ but it is not used. By default, silence in the input log-mel spectrogram is ignored.
795
+ head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
796
+ Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
797
+
798
+ - 1 indicates the head is **not masked**,
799
+ - 0 indicates the head is **masked**.
800
+ output_attentions (`bool`, *optional*):
801
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
802
+ returned tensors for more detail.
803
+ output_hidden_states (`bool`, *optional*):
804
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
805
+ for more detail.
806
+ return_dict (`bool`, *optional*):
807
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
808
+ """
809
+
810
+ expected_seq_length = self.config.max_source_positions * self.conv1.stride[0] * self.conv2.stride[0]
811
+ if input_features.shape[-1] != expected_seq_length:
812
+ raise ValueError(
813
+ f"Speech expects the mel input features to be of length {expected_seq_length}, but found {input_features.shape[-1]}. Make sure to pad the input mel features to {expected_seq_length}."
814
+ )
815
+
816
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
817
+ output_hidden_states = (
818
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
819
+ )
820
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
821
+ inputs_embeds = nn.functional.gelu(self.conv1(input_features))
822
+ inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds))
823
+
824
+ inputs_embeds = inputs_embeds.permute(0, 2, 1)
825
+ embed_pos = self.embed_positions.weight
826
+
827
+ hidden_states = inputs_embeds + embed_pos
828
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
829
+
830
+ encoder_states = () if output_hidden_states else None
831
+ all_attentions = () if output_attentions else None
832
+
833
+ # check if head_mask has a correct number of layers specified if desired
834
+ if head_mask is not None:
835
+ assert head_mask.size()[0] == (
836
+ len(self.layers)
837
+ ), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
838
+
839
+ for idx, encoder_layer in enumerate(self.layers):
840
+ if output_hidden_states:
841
+ encoder_states = encoder_states + (hidden_states,)
842
+ # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
843
+ to_drop = False
844
+ if self.training:
845
+ dropout_probability = torch.rand([])
846
+ if dropout_probability < self.layerdrop: # skip the layer
847
+ to_drop = True
848
+
849
+ if to_drop:
850
+ layer_outputs = (None, None)
851
+ else:
852
+ if self.gradient_checkpointing and self.training:
853
+ layer_outputs = self._gradient_checkpointing_func(
854
+ encoder_layer.__call__,
855
+ hidden_states,
856
+ None,
857
+ (head_mask[idx] if head_mask is not None else None),
858
+ output_attentions,
859
+ )
860
+ else:
861
+ layer_outputs = encoder_layer(
862
+ hidden_states,
863
+ None,
864
+ layer_head_mask=(head_mask[idx] if head_mask is not None else None),
865
+ output_attentions=output_attentions,
866
+ )
867
+
868
+ hidden_states = layer_outputs[0]
869
+
870
+ if output_attentions:
871
+ all_attentions = all_attentions + (layer_outputs[1],)
872
+
873
+ hidden_states = self.layer_norm(hidden_states)
874
+ if output_hidden_states:
875
+ encoder_states = encoder_states + (hidden_states,)
876
+
877
+ if not return_dict:
878
+ return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
879
+ return BaseModelOutput(
880
+ last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
881
+ )
882
+
883
+
884
+ # copied from Qwen2AudioCausalLMOutputWithPast
885
+ @dataclass
886
+ class MERaLiONOutputWithPast(ModelOutput):
887
+ """
888
+ Base class for MERaLiON causal language model (or autoregressive) outputs.
889
+
890
+ Args:
891
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
892
+ Language modeling loss (for next-token prediction).
893
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
894
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
895
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
896
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
897
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`)
898
+
899
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
900
+ `past_key_values` input) to speed up sequential decoding.
901
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
902
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
903
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
904
+
905
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
906
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
907
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
908
+ sequence_length)`.
909
+
910
+ Attention weights after the attention softmax, used to compute the weighted average in the self-attention
911
+ heads.
912
+ attention_mask (`torch.FloatTensor`, *optional*):
913
+ Attention mask, used to update the attention mask and position_ids.
914
+ """
915
+
916
+ loss: Optional[torch.FloatTensor] = None
917
+ logits: torch.FloatTensor = None
918
+ past_key_values: Optional[List[torch.FloatTensor]] = None
919
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
920
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
921
+ attention_mask: Optional[torch.FloatTensor] = None
922
+
923
+
924
+ MERALION_START_DOCSTRING = r"""
925
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
926
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
927
+ etc.)
928
+
929
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
930
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
931
+ and behavior.
932
+
933
+ Parameters:
934
+ config ([`MERaLiONConfig`]):
935
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
936
+ load the weights associated with the model, only the configuration. Check out the
937
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
938
+ """
939
+
940
+
941
+ @add_start_docstrings(
942
+ "The bare MERaLiON Model outputting raw hidden-states without any specific head on top.",
943
+ MERALION_START_DOCSTRING,
944
+ )
945
+ class MERaLiONPreTrainedModel(PreTrainedModel):
946
+ config_class = MERaLiONConfig
947
+ base_model_prefix = "model"
948
+ supports_gradient_checkpointing = True
949
+ _no_split_modules = ["MERaLiONSpeechEncoderLayer", "MERaLiONSpeechDecoderLayer", "MERaLiONTextDecoderLayer"]
950
+ _supports_flash_attn_2 = True
951
+ _supports_sdpa = True
952
+ _supports_cache_class = True
953
+ _supports_static_cache = True
954
+
955
+ def _init_weights(self, module):
956
+ # important: this ported version of Qwen2Audio isn't meant for training from scratch - only
957
+ # inference and fine-tuning - so the proper init weights code has been removed
958
+ std = self.config.init_std if hasattr(self.config, "init_std") else self.config.speech_config.init_std
959
+
960
+ if isinstance(module, (nn.Linear, nn.Conv1d)):
961
+ module.weight.data.normal_(mean=0.0, std=std)
962
+ if module.bias is not None:
963
+ module.bias.data.zero_()
964
+ elif isinstance(module, nn.Embedding):
965
+ module.weight.data.normal_(mean=0.0, std=std)
966
+ if module.padding_idx is not None:
967
+ module.weight.data[module.padding_idx].zero_()
968
+
969
+ @property
970
+ def _supports_sdpa(self):
971
+ """
972
+ Retrieve language_model's attribute to check whether the model supports
973
+ SDPA or not.
974
+ """
975
+ return self.text_decoder._supports_sdpa
976
+
977
+ class MERaLiONSpeechAudioAdaper(nn.Module):
978
+ def __init__(
979
+ self,
980
+ config,
981
+ **kwargs
982
+ ):
983
+ super(MERaLiONSpeechAudioAdaper, self).__init__()
984
+ speech_audio_encoder_output_dim = config.speech_config.d_model
985
+ llm_input_hidden_size = config.text_config.hidden_size
986
+ speech_mlp_scale_factor = config.speech_mlp_scale_factor
987
+
988
+ self.speech_mlp_scale_factor = speech_mlp_scale_factor
989
+ self.mlp_adapter = nn.Sequential(
990
+ nn.Linear(
991
+ in_features=speech_audio_encoder_output_dim * speech_mlp_scale_factor,
992
+ out_features=speech_audio_encoder_output_dim
993
+ ),
994
+ nn.SiLU(),
995
+ nn.Dropout(0.1),
996
+ )
997
+
998
+ self.speech_llm_proj = nn.Sequential(
999
+ nn.Linear(
1000
+ speech_audio_encoder_output_dim,
1001
+ speech_audio_encoder_output_dim * 4
1002
+ ),
1003
+ nn.SiLU(),
1004
+ nn.Dropout(0.1),
1005
+
1006
+ nn.Linear(
1007
+ speech_audio_encoder_output_dim * 4,
1008
+ llm_input_hidden_size
1009
+ ),
1010
+ )
1011
+
1012
+ def forward(self, speech_embeds, **kwargs):
1013
+ B, T, C = speech_embeds.shape
1014
+ speech_embeds = self.mlp_adapter(
1015
+ speech_embeds.reshape(
1016
+ B,
1017
+ T // self.speech_mlp_scale_factor,
1018
+ C * self.speech_mlp_scale_factor,
1019
+ )
1020
+ )
1021
+ return self.speech_llm_proj(speech_embeds)
1022
+
1023
+
1024
+ MERALION_INPUTS_DOCSTRING = r"""
1025
+ Args:
1026
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
1027
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
1028
+ it.
1029
+
1030
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
1031
+ [`PreTrainedTokenizer.__call__`] for details.
1032
+
1033
+ [What are input IDs?](../glossary#input-ids)
1034
+ input_ids_left (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
1035
+ Indices of left-padded input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
1036
+ it.
1037
+ input_ids_right (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
1038
+ Indices of right-padded input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
1039
+ it.
1040
+ input_features (`torch.FloatTensor` of shape `(batch_size, feature_size, feature_sequence_length)`, *optional*):
1041
+ Float values of mel features extracted from the raw speech waveform. Raw speech waveform can be obtained by
1042
+ loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a `numpy.ndarray`, *e.g.* via
1043
+ the soundfile library (`pip install soundfile`). To prepare the array into `input_features`, the
1044
+ [`AutoFeatureExtractor`] should be used for extracting the mel features, padding and conversion into a
1045
+ tensor of type `torch.FloatTensor`. See [`~WhisperFeatureExtractor.__call__`]
1046
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
1047
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
1048
+
1049
+ - 1 for tokens that are **not masked**,
1050
+ - 0 for tokens that are **masked**.
1051
+
1052
+ [What are attention masks?](../glossary#attention-mask)
1053
+
1054
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
1055
+ [`PreTrainedTokenizer.__call__`] for details.
1056
+
1057
+ If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
1058
+ `past_key_values`).
1059
+
1060
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
1061
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
1062
+ information on the default strategy.
1063
+
1064
+ - 1 indicates the head is **not masked**,
1065
+ - 0 indicates the head is **masked**.
1066
+
1067
+ attention_mask_left (`torch.Tensor` of shape `(batch_size, feature_sequence_length)`, *optional*):
1068
+ Mask to avoid performing attention on padding feature indices. Mask values selected in `[0, 1]`:
1069
+
1070
+ - 1 for tokens that are **not masked**,
1071
+ - 0 for tokens that are **masked**.
1072
+ attention_mask_right (`torch.Tensor` of shape `(batch_size, feature_sequence_length)`, *optional*):
1073
+ Mask to avoid performing attention on padding feature indices. Mask values selected in `[0, 1]`:
1074
+
1075
+ - 1 for tokens that are **not masked**,
1076
+ - 0 for tokens that are **masked**.
1077
+ feature_attention_mask (`torch.Tensor` of shape `(batch_size, feature_sequence_length)`, *optional*):
1078
+ Mask to avoid performing attention on padding feature indices. Mask values selected in `[0, 1]`:
1079
+
1080
+ - 1 for tokens that are **not masked**,
1081
+ - 0 for tokens that are **masked**.
1082
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1083
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
1084
+ config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
1085
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
1086
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
1087
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
1088
+ `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
1089
+
1090
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
1091
+ blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
1092
+
1093
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
1094
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
1095
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
1096
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
1097
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
1098
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
1099
+ model's internal embedding lookup matrix.
1100
+ use_cache (`bool`, *optional*):
1101
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
1102
+ `past_key_values`).
1103
+ output_attentions (`bool`, *optional*):
1104
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
1105
+ tensors for more detail.
1106
+ output_hidden_states (`bool`, *optional*):
1107
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
1108
+ more detail.
1109
+ return_dict (`bool`, *optional*):
1110
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
1111
+ """
1112
+
1113
+ @add_start_docstrings(
1114
+ """The MERALION model which consists of a audio backbone and a language model.""",
1115
+ MERALION_START_DOCSTRING,
1116
+ )
1117
+ class MERaLiONForConditionalGeneration(MERaLiONPreTrainedModel, GenerationMixin):
1118
+ def __init__(self, config: MERaLiONConfig):
1119
+ config.text_config._attn_implementation = config._attn_implementation
1120
+ config.speech_config._attn_implementation = config._attn_implementation
1121
+
1122
+ super().__init__(config)
1123
+
1124
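+ # Speech path: the speech encoder produces frame-level audio features; `ln_speech` normalizes them and
+ # `speech_audio_adapter` maps them into the text decoder's embedding space (see `forward`).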
+ self.speech_encoder = MERaLiONSpeechEncoder(config.speech_config)
1125
+ # self.speech_encoder = AutoModel.from_config(config.audio_config, attn_implementation=config._attn_implementation)
1126
+
1127
+ self.ln_speech = nn.LayerNorm(config.speech_config.d_model)
1128
+ self.speech_audio_adapter = MERaLiONSpeechAudioAdaper(config)
1129
+ self.vocab_size = config.text_config.vocab_size
1130
+ self.text_decoder = MERaLiONTextForCausalLM(config.text_config)
1131
+ self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
1132
+ self._padding_side = "left" # set it to left by default, user can use setter to change padding_sides
1133
+ self.post_init()
1134
+
1135
+ @property
1136
+ def padding_side(self):
1137
+ return self._padding_side
1138
+
1139
+ @padding_side.setter
1140
+ def padding_side(self, padding_side: str):
1141
+ if padding_side not in ["left", "right"]:
1142
+ raise ValueError(f"{padding_side} is not `left` or `right`.")
1143
+ self._padding_side = padding_side
1144
+
1145
+ # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.get_input_embeddings
1146
+ def get_input_embeddings(self):
1147
+ return self.text_decoder.get_input_embeddings()
1148
+
1149
+ # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.set_input_embeddings
1150
+ def set_input_embeddings(self, value):
1151
+ self.text_decoder.set_input_embeddings(value)
1152
+
1153
+ # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.get_output_embeddings
1154
+ def get_output_embeddings(self):
1155
+ return self.text_decoder.get_output_embeddings()
1156
+
1157
+ # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.set_output_embeddings
1158
+ def set_output_embeddings(self, new_embeddings):
1159
+ self.text_decoder.set_output_embeddings(new_embeddings)
1160
+
1161
+ # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.set_decoder
1162
+ def set_decoder(self, decoder):
1163
+ self.text_decoder.set_decoder(decoder)
1164
+
1165
+ # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.get_decoder
1166
+ def get_decoder(self):
1167
+ return self.text_decoder.get_decoder()
1168
+
1169
+ # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.tie_weights
1170
+ def tie_weights(self):
1171
+ return self.text_decoder.tie_weights()
1172
+
1173
+ # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.resize_token_embeddings
1174
+ def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None) -> nn.Embedding:
1175
+ model_embeds = self.text_decoder.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
1176
+ # update vocab size
1177
+ self.config.text_config.vocab_size = model_embeds.num_embeddings
1178
+ self.vocab_size = model_embeds.num_embeddings
1179
+ return model_embeds
1180
+
1181
+ def _get_multimodal_input_embeds(
1182
+ self,
1183
+ input_ids_left,
1184
+ input_ids_right,
1185
+ attention_mask_left,
1186
+ attention_mask_right,
1187
+ speech_audio_contexts_embeds,
1188
+ speech_audio_contexts_atts,
1189
+ ):
1190
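+ # Embed the text on either side of the audio and splice the audio context in between:
+ # [left prompt | speech/audio embeddings | right prompt], concatenating the attention masks to match.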
+ input_embeds_left = self.text_decoder.base_model.embed_tokens(input_ids_left)
1191
+ input_embeds_right = self.text_decoder.base_model.embed_tokens(input_ids_right)
1192
+
1193
+ multimodal_embeds = torch.cat(
1194
+ [
1195
+ input_embeds_left,
1196
+ speech_audio_contexts_embeds,
1197
+ input_embeds_right,
1198
+ ],
1199
+ dim=1,
1200
+ )
1201
+
1202
+ multimodal_attention_mask = torch.cat(
1203
+ [
1204
+ attention_mask_left,
1205
+ speech_audio_contexts_atts,
1206
+ attention_mask_right,
1207
+ ],
1208
+ dim=1,
1209
+ )
1210
+ return multimodal_embeds, multimodal_attention_mask
1211
+
1212
+ @add_start_docstrings_to_model_forward(MERALION_INPUTS_DOCSTRING)
1213
+ @replace_return_docstrings(output_type=MERaLiONOutputWithPast, config_class=_CONFIG_FOR_DOC)
1214
+ def forward(
1215
+ self,
1216
+ input_ids: torch.LongTensor = None,
1217
+ input_features: torch.FloatTensor = None,
1218
+ attention_mask: Optional[torch.Tensor] = None,
1219
+ feature_attention_mask: Optional[torch.Tensor] = None,
1220
+ position_ids: Optional[torch.LongTensor] = None,
1221
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1222
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1223
+ labels: Optional[torch.LongTensor] = None,
1224
+ use_cache: Optional[bool] = None,
1225
+ cache_position: Optional[torch.LongTensor] = None,
1226
+ output_attentions: Optional[bool] = None,
1227
+ output_hidden_states: Optional[bool] = None,
1228
+ return_dict: Optional[bool] = None,
1229
+ ) -> Union[Tuple, MERaLiONOutputWithPast]:
1230
+ r"""
1231
+ Args:
1232
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1233
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
1234
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1235
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1236
+
1237
+ Returns:
1238
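+
+ Example:
+
+ A minimal sketch. The checkpoint id, the prompt template with its speech placeholder, and the processor
+ call signature are assumptions about the accompanying processor, not guarantees made by this file.
+
+ ```python
+ >>> import soundfile as sf
+ >>> from transformers import AutoProcessor
+
+ >>> repo_id = "MERaLiON/MERaLiON-AudioLLM-Whisper-SEA-LION"  # hypothetical checkpoint id
+ >>> processor = AutoProcessor.from_pretrained(repo_id, trust_remote_code=True)
+ >>> model = MERaLiONForConditionalGeneration.from_pretrained(repo_id)
+
+ >>> prompt = "Please transcribe this speech: <SpeechHere>"  # assumed placeholder expanded by the processor
+ >>> audio, _ = sf.read("speech.wav")
+ >>> inputs = processor(text=prompt, audios=audio)  # assumed to return input_ids, attention_mask, input_features, feature_attention_mask
+
+ >>> generated_ids = model.generate(**inputs, max_new_tokens=128)
+ >>> processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
+ ```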
+ """
1239
+
1240
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1241
+ output_hidden_states = (
1242
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1243
+ )
1244
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1245
+
1246
+ speech_encoder_device = self.speech_encoder.device
1247
+
1248
+ if input_features is not None:
1249
+ input_features = input_features.to(speech_encoder_device)
1250
+ feature_attention_mask = feature_attention_mask.to(speech_encoder_device)
1251
+
1252
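+ # When embeddings have not been pre-computed, encode the audio: speech encoder -> `ln_speech` -> adapter,
+ # producing frames that live in the text decoder's embedding space.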
+ if inputs_embeds is None:
1253
+ speech_contexts_embeds = self.speech_encoder(input_features, attention_mask=feature_attention_mask).last_hidden_state
1254
+ speech_contexts_embeds = self.ln_speech(speech_contexts_embeds)
1255
+ speech_audio_contexts_embeds = self.speech_audio_adapter(speech_contexts_embeds)
1256
+
1257
+ inputs_embeds = self.text_decoder.base_model.embed_tokens(input_ids)
1258
+
1259
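+ # Replace the embeddings at speech placeholder positions (`speech_token_index`) with the adapted audio
+ # frames; the number of placeholders per sample is expected to match the adapter's output length.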
+ speech_mask = (input_ids == self.config.speech_token_index).unsqueeze(-1)
1260
+ speech_mask = speech_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
1261
+
1262
+ inputs_embeds = inputs_embeds.masked_scatter(speech_mask, speech_audio_contexts_embeds)
1263
+
1264
+ input_ids = None
1265
+
1266
+ outputs = self.text_decoder(
1267
+ input_ids=input_ids,
1268
+ attention_mask=attention_mask,
1269
+ position_ids=position_ids,
1270
+ past_key_values=past_key_values,
1271
+ inputs_embeds=inputs_embeds,
1272
+ use_cache=use_cache,
1273
+ cache_position=cache_position,
1274
+ output_attentions=output_attentions,
1275
+ output_hidden_states=output_hidden_states,
1276
+ return_dict=return_dict,
1277
+ labels=labels
1278
+ )
1279
+
1280
+ return outputs
1281
+
1282
+ # Adapted from transformers.models.gemma2.modeling_gemma2.Gemma2ForCausalLM.prepare_inputs_for_generation
1283
+ def prepare_inputs_for_generation(
1284
+ self,
1285
+ input_ids,
1286
+ attention_mask=None,
1287
+ input_features=None,
1288
+ feature_attention_mask=None,
1289
+ past_key_values=None,
1290
+ inputs_embeds=None,
1291
+ cache_position=None,
1292
+ position_ids=None,
1293
+ use_cache=None,
1294
+ **kwargs,
1295
+ ):
1296
+ # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
1297
+ # Exception 1: when passing input_embeds, input_ids may be missing entries
1298
+ # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
1299
+ is_first_step = cache_position[0].item() == 0
1300
+ if past_key_values is not None:
1301
+ if inputs_embeds is not None: # Exception 1
1302
+ input_ids = input_ids[:, -cache_position.shape[0] :]
1303
+ elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
1304
+ input_ids = input_ids[:, cache_position]
1305
+
1306
+ if attention_mask is not None and position_ids is None:
1307
+ # create position_ids on the fly for batch generation
1308
+ position_ids = attention_mask.long().cumsum(-1) - 1
1309
+ position_ids.masked_fill_(attention_mask == 0, 1)
1310
+ if past_key_values:
1311
+ position_ids = position_ids[:, -input_ids.shape[1] :]
1312
+ # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s
1313
+ # `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride
1314
+ # during the decoding. Here, simply using `.contiguous()` is not sufficient as in the
1315
+ # batch size = 1 case, `position_ids` is already contiguous but with varying stride
1316
+ # which retriggers a capture.
1317
+ position_ids = position_ids.clone(memory_format=torch.contiguous_format)
1318
+
1319
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1320
+ if inputs_embeds is not None and is_first_step:
1321
+ model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
1322
+ else:
1323
+ # The clone here is for the same reason as for `position_ids`.
1324
+ model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}
1325
+
1326
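+ # Gemma2-style decoders use a `HybridCache`; when one is active (and flash-attention-2 is not), expand the
+ # 2D padding mask into the full 4D causal mask sized to the cache's max length so static cache slots line up.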
+ if (
1327
+ isinstance(past_key_values, HybridCache)
1328
+ and attention_mask.ndim == 2
1329
+ and not self.config._attn_implementation == "flash_attention_2"
1330
+ ):
1331
+ if model_inputs["inputs_embeds"] is not None:
1332
+ batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
1333
+ device = model_inputs["inputs_embeds"].device
1334
+ else:
1335
+ batch_size, sequence_length = model_inputs["input_ids"].shape
1336
+ device = model_inputs["input_ids"].device
1337
+ dtype = self.text_decoder.lm_head.weight.dtype
1338
+ min_dtype = torch.finfo(dtype).min
1339
+ attention_mask = _prepare_4d_causal_attention_mask_with_cache_position(
1340
+ attention_mask,
1341
+ sequence_length=sequence_length,
1342
+ target_length=past_key_values.get_max_length(),
1343
+ dtype=dtype,
1344
+ device=device,
1345
+ min_dtype=min_dtype,
1346
+ cache_position=cache_position,
1347
+ batch_size=batch_size,
1348
+ )
1349
+
1350
+ model_inputs.update(
1351
+ {
1352
+ "attention_mask": attention_mask,
1353
+ "position_ids": position_ids,
1354
+ "cache_position": cache_position,
1355
+ "past_key_values": past_key_values,
1356
+ "use_cache": use_cache
1357
+ }
1358
+ )
1359
+
1360
+ # `input_features` only need to be passed on the first generation step; later steps reuse the cached audio context via `past_key_values`.
1361
+ if is_first_step:
1362
+ model_inputs["input_features"] = input_features
1363
+ model_inputs["feature_attention_mask"] = feature_attention_mask
1364
+
1365
+ return model_inputs
1366
+
1367
+ def _reorder_cache(self, *args, **kwargs):
1368
+ return self.text_decoder._reorder_cache(*args, **kwargs)