yangwang825 committed
Commit 9f52360
Parent: 17feb1f

Create modeling_audio_spectrogram_transformer.py

modeling_audio_spectrogram_transformer.py ADDED
@@ -0,0 +1,664 @@
1
+ # coding=utf-8
2
+ # Copyright 2022 MIT and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """PyTorch Audio Spectrogram Transformer (AST) model."""
16
+
17
+ import math
18
+ from typing import Dict, List, Optional, Set, Tuple, Union
19
+
20
+ import torch
21
+ import torch.utils.checkpoint
22
+ from torch import nn
23
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
24
+
25
+ from transformers.activations import ACT2FN
26
+ from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, SequenceClassifierOutput
27
+ from transformers.modeling_utils import PreTrainedModel
28
+ from transformers.pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
29
+ from transformers.utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
30
+ from .configuration_audio_spectrogram_transformer import ASTConfig
31
+
32
+
33
+ logger = logging.get_logger(__name__)
34
+
35
+ # General docstring
36
+ _CONFIG_FOR_DOC = "ASTConfig"
37
+
38
+ # Base docstring
39
+ _CHECKPOINT_FOR_DOC = "MIT/ast-finetuned-audioset-10-10-0.4593"
40
+ _EXPECTED_OUTPUT_SHAPE = [1, 1214, 768]
41
+
42
+ # Audio classification docstring
43
+ _SEQ_CLASS_CHECKPOINT = "MIT/ast-finetuned-audioset-10-10-0.4593"
44
+ _SEQ_CLASS_EXPECTED_OUTPUT = "'Speech'"
45
+ _SEQ_CLASS_EXPECTED_LOSS = 0.17
46
+
47
+
48
+ class ASTEmbeddings(nn.Module):
49
+ """
50
+ Construct the CLS token, distillation token, position and patch embeddings.
51
+ """
52
+
53
+ def __init__(self, config: ASTConfig) -> None:
54
+ super().__init__()
55
+
56
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
57
+ self.distillation_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
58
+ self.patch_embeddings = ASTPatchEmbeddings(config)
59
+
60
+ frequency_out_dimension, time_out_dimension = self.get_shape(config)
61
+ num_patches = frequency_out_dimension * time_out_dimension
62
+ self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 2, config.hidden_size))
63
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
64
+ self.config = config
65
+
66
+ def get_shape(self, config):
67
+ # see Karpathy's cs231n blog on how to calculate the output dimensions
68
+ # https://cs231n.github.io/convolutional-networks/#conv
69
+ if config.frequency_patch_size is not None:
70
+ frequency_out_dimension = (config.num_mel_bins - config.frequency_patch_size) // config.frequency_stride + 1
71
+ else:
72
+ frequency_out_dimension = (config.num_mel_bins - config.patch_size) // config.frequency_stride + 1
73
+ if config.time_patch_size is not None:
74
+ time_out_dimension = (config.max_length - config.time_patch_size) // config.time_stride + 1
75
+ else:
76
+ time_out_dimension = (config.max_length - config.patch_size) // config.time_stride + 1
77
+
78
+ return frequency_out_dimension, time_out_dimension
79
+
80
+ def forward(self, input_values: torch.Tensor) -> torch.Tensor:
81
+ batch_size = input_values.shape[0]
82
+ embeddings = self.patch_embeddings(input_values)
83
+
84
+ cls_tokens = self.cls_token.expand(batch_size, -1, -1)
85
+ distillation_tokens = self.distillation_token.expand(batch_size, -1, -1)
86
+ embeddings = torch.cat((cls_tokens, distillation_tokens, embeddings), dim=1)
87
+ embeddings = embeddings + self.position_embeddings
88
+ embeddings = self.dropout(embeddings)
89
+
90
+ return embeddings
91
+
92
+
93
+ class ASTPatchEmbeddings(nn.Module):
94
+ """
95
+ This class turns `input_values` into the initial `hidden_states` (patch embeddings) of shape `(batch_size,
96
+ seq_length, hidden_size)` to be consumed by a Transformer.
97
+ """
98
+
99
+ def __init__(self, config):
100
+ super().__init__()
101
+ if config.frequency_patch_size is not None and config.time_patch_size is not None:
102
+ kernel_size = (config.frequency_patch_size, config.time_patch_size)
103
+ else:
104
+ kernel_size = (config.patch_size, config.patch_size)
105
+ frequency_stride = config.frequency_stride
106
+ time_stride = config.time_stride
107
+
108
+ self.projection = nn.Conv2d(
109
+ 1, config.hidden_size, kernel_size=kernel_size, stride=(frequency_stride, time_stride)
110
+ )
111
+
112
+ def forward(self, input_values: torch.Tensor) -> torch.Tensor:
113
+ input_values = input_values.unsqueeze(1)
114
+ input_values = input_values.transpose(2, 3)
115
+ embeddings = self.projection(input_values).flatten(2).transpose(1, 2)
116
+ return embeddings
117
+
118
+
119
+ # Copied from transformers.models.vit.modeling_vit.ViTSelfAttention with ViT->AST
120
+ class ASTSelfAttention(nn.Module):
121
+ def __init__(self, config: ASTConfig) -> None:
122
+ super().__init__()
123
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
124
+ raise ValueError(
125
+ f"The hidden size {config.hidden_size,} is not a multiple of the number of attention "
126
+ f"heads {config.num_attention_heads}."
127
+ )
128
+
129
+ self.num_attention_heads = config.num_attention_heads
130
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
131
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
132
+
133
+ self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
134
+ self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
135
+ self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
136
+
137
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
138
+
139
+ def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
140
+ new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
141
+ x = x.view(new_x_shape)
142
+ return x.permute(0, 2, 1, 3)
143
+
144
+ def forward(
145
+ self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
146
+ ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
147
+ mixed_query_layer = self.query(hidden_states)
148
+
149
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
150
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
151
+ query_layer = self.transpose_for_scores(mixed_query_layer)
152
+
153
+ # Take the dot product between "query" and "key" to get the raw attention scores.
154
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
155
+
156
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
157
+
158
+ # Normalize the attention scores to probabilities.
159
+ attention_probs = nn.functional.softmax(attention_scores, dim=-1)
160
+
161
+ # This is actually dropping out entire tokens to attend to, which might
162
+ # seem a bit unusual, but is taken from the original Transformer paper.
163
+ attention_probs = self.dropout(attention_probs)
164
+
165
+ # Mask heads if we want to
166
+ if head_mask is not None:
167
+ attention_probs = attention_probs * head_mask
168
+
169
+ context_layer = torch.matmul(attention_probs, value_layer)
170
+
171
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
172
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
173
+ context_layer = context_layer.view(new_context_layer_shape)
174
+
175
+ outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
176
+
177
+ return outputs
178
+
179
+
180
+ # Copied from transformers.models.vit.modeling_vit.ViTSdpaSelfAttention with ViT->AST
181
+ class ASTSdpaSelfAttention(ASTSelfAttention):
182
+ def __init__(self, config: ASTConfig) -> None:
183
+ super().__init__(config)
184
+ self.attention_probs_dropout_prob = config.attention_probs_dropout_prob
185
+
186
+ def forward(
187
+ self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
188
+ ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
189
+ mixed_query_layer = self.query(hidden_states)
190
+
191
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
192
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
193
+ query_layer = self.transpose_for_scores(mixed_query_layer)
194
+
195
+ context_layer = torch.nn.functional.scaled_dot_product_attention(
196
+ query_layer,
197
+ key_layer,
198
+ value_layer,
199
+ head_mask,
200
+ self.attention_probs_dropout_prob if self.training else 0.0,
201
+ is_causal=False,
202
+ scale=None,
203
+ )
204
+
205
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
206
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
207
+ context_layer = context_layer.view(new_context_layer_shape)
208
+
209
+ return context_layer, None
210
+
211
+
212
+ # Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->AST
213
+ class ASTSelfOutput(nn.Module):
214
+ """
215
+ The residual connection is defined in ASTLayer instead of here (as is the case with other models), due to the
216
+ layernorm applied before each block.
217
+ """
218
+
219
+ def __init__(self, config: ASTConfig) -> None:
220
+ super().__init__()
221
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
222
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
223
+
224
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
225
+ hidden_states = self.dense(hidden_states)
226
+ hidden_states = self.dropout(hidden_states)
227
+
228
+ return hidden_states
229
+
230
+
231
+ # Copied from transformers.models.vit.modeling_vit.ViTAttention with ViT->AST
232
+ class ASTAttention(nn.Module):
233
+ def __init__(self, config: ASTConfig) -> None:
234
+ super().__init__()
235
+ self.attention = ASTSelfAttention(config)
236
+ self.output = ASTSelfOutput(config)
237
+ self.pruned_heads = set()
238
+
239
+ def prune_heads(self, heads: Set[int]) -> None:
240
+ if len(heads) == 0:
241
+ return
242
+ heads, index = find_pruneable_heads_and_indices(
243
+ heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads
244
+ )
245
+
246
+ # Prune linear layers
247
+ self.attention.query = prune_linear_layer(self.attention.query, index)
248
+ self.attention.key = prune_linear_layer(self.attention.key, index)
249
+ self.attention.value = prune_linear_layer(self.attention.value, index)
250
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
251
+
252
+ # Update hyper params and store pruned heads
253
+ self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads)
254
+ self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads
255
+ self.pruned_heads = self.pruned_heads.union(heads)
256
+
257
+ def forward(
258
+ self,
259
+ hidden_states: torch.Tensor,
260
+ head_mask: Optional[torch.Tensor] = None,
261
+ output_attentions: bool = False,
262
+ ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
263
+ self_outputs = self.attention(hidden_states, head_mask, output_attentions)
264
+
265
+ attention_output = self.output(self_outputs[0], hidden_states)
266
+
267
+ outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
268
+ return outputs
269
+
270
+
271
+ # Copied from transformers.models.vit.modeling_vit.ViTSdpaAttention with ViT->AST
272
+ class ASTSdpaAttention(ASTAttention):
273
+ def __init__(self, config: ASTConfig) -> None:
274
+ super().__init__(config)
275
+ self.attention = ASTSdpaSelfAttention(config)
276
+
277
+
278
+ # Copied from transformers.models.vit.modeling_vit.ViTIntermediate with ViT->AST
279
+ class ASTIntermediate(nn.Module):
280
+ def __init__(self, config: ASTConfig) -> None:
281
+ super().__init__()
282
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
283
+ if isinstance(config.hidden_act, str):
284
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
285
+ else:
286
+ self.intermediate_act_fn = config.hidden_act
287
+
288
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
289
+ hidden_states = self.dense(hidden_states)
290
+ hidden_states = self.intermediate_act_fn(hidden_states)
291
+
292
+ return hidden_states
293
+
294
+
295
+ # Copied from transformers.models.vit.modeling_vit.ViTOutput with ViT->AST
296
+ class ASTOutput(nn.Module):
297
+ def __init__(self, config: ASTConfig) -> None:
298
+ super().__init__()
299
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
300
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
301
+
302
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
303
+ hidden_states = self.dense(hidden_states)
304
+ hidden_states = self.dropout(hidden_states)
305
+
306
+ hidden_states = hidden_states + input_tensor
307
+
308
+ return hidden_states
309
+
310
+
311
+ AST_ATTENTION_CLASSES = {
312
+ "eager": ASTAttention,
313
+ "sdpa": ASTSdpaAttention,
314
+ }
315
+
316
+
317
+ # Copied from transformers.models.vit.modeling_vit.ViTLayer with ViT->AST,VIT->AST
318
+ class ASTLayer(nn.Module):
319
+ """This corresponds to the Block class in the timm implementation."""
320
+
321
+ def __init__(self, config: ASTConfig) -> None:
322
+ super().__init__()
323
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
324
+ self.seq_len_dim = 1
325
+ self.attention = AST_ATTENTION_CLASSES[config._attn_implementation](config)
326
+ self.intermediate = ASTIntermediate(config)
327
+ self.output = ASTOutput(config)
328
+ self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
329
+ self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
330
+
331
+ def forward(
332
+ self,
333
+ hidden_states: torch.Tensor,
334
+ head_mask: Optional[torch.Tensor] = None,
335
+ output_attentions: bool = False,
336
+ ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
337
+ self_attention_outputs = self.attention(
338
+ self.layernorm_before(hidden_states), # in AST, layernorm is applied before self-attention
339
+ head_mask,
340
+ output_attentions=output_attentions,
341
+ )
342
+ attention_output = self_attention_outputs[0]
343
+ outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
344
+
345
+ # first residual connection
346
+ hidden_states = attention_output + hidden_states
347
+
348
+ # in AST, layernorm is also applied after self-attention
349
+ layer_output = self.layernorm_after(hidden_states)
350
+ layer_output = self.intermediate(layer_output)
351
+
352
+ # second residual connection is done here
353
+ layer_output = self.output(layer_output, hidden_states)
354
+
355
+ outputs = (layer_output,) + outputs
356
+
357
+ return outputs
358
+
359
+
360
+ # Copied from transformers.models.vit.modeling_vit.ViTEncoder with ViT->AST
361
+ class ASTEncoder(nn.Module):
362
+ def __init__(self, config: ASTConfig) -> None:
363
+ super().__init__()
364
+ self.config = config
365
+ self.layer = nn.ModuleList([ASTLayer(config) for _ in range(config.num_hidden_layers)])
366
+ self.gradient_checkpointing = False
367
+
368
+ def forward(
369
+ self,
370
+ hidden_states: torch.Tensor,
371
+ head_mask: Optional[torch.Tensor] = None,
372
+ output_attentions: bool = False,
373
+ output_hidden_states: bool = False,
374
+ return_dict: bool = True,
375
+ ) -> Union[tuple, BaseModelOutput]:
376
+ all_hidden_states = () if output_hidden_states else None
377
+ all_self_attentions = () if output_attentions else None
378
+
379
+ for i, layer_module in enumerate(self.layer):
380
+ if output_hidden_states:
381
+ all_hidden_states = all_hidden_states + (hidden_states,)
382
+
383
+ layer_head_mask = head_mask[i] if head_mask is not None else None
384
+
385
+ if self.gradient_checkpointing and self.training:
386
+ layer_outputs = self._gradient_checkpointing_func(
387
+ layer_module.__call__,
388
+ hidden_states,
389
+ layer_head_mask,
390
+ output_attentions,
391
+ )
392
+ else:
393
+ layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions)
394
+
395
+ hidden_states = layer_outputs[0]
396
+
397
+ if output_attentions:
398
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
399
+
400
+ if output_hidden_states:
401
+ all_hidden_states = all_hidden_states + (hidden_states,)
402
+
403
+ if not return_dict:
404
+ return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
405
+ return BaseModelOutput(
406
+ last_hidden_state=hidden_states,
407
+ hidden_states=all_hidden_states,
408
+ attentions=all_self_attentions,
409
+ )
410
+
411
+
412
+ class ASTPreTrainedModel(PreTrainedModel):
413
+ """
414
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
415
+ models.
416
+ """
417
+
418
+ config_class = ASTConfig
419
+ base_model_prefix = "audio_spectrogram_transformer"
420
+ main_input_name = "input_values"
421
+ supports_gradient_checkpointing = True
422
+ _supports_sdpa = True
423
+
424
+ # Copied from transformers.models.deit.modeling_deit.DeiTPreTrainedModel._init_weights
425
+ def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
426
+ """Initialize the weights"""
427
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
428
+ # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid
429
+ # `trunc_normal_cpu` not implemented in `half` issues
430
+ module.weight.data = nn.init.trunc_normal_(
431
+ module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range
432
+ ).to(module.weight.dtype)
433
+ if module.bias is not None:
434
+ module.bias.data.zero_()
435
+ elif isinstance(module, nn.LayerNorm):
436
+ module.bias.data.zero_()
437
+ module.weight.data.fill_(1.0)
438
+
439
+
440
+ AUDIO_SPECTROGRAM_TRANSFORMER_START_DOCSTRING = r"""
441
+ This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
442
+ as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and
443
+ behavior.
444
+
445
+ Parameters:
446
+ config ([`ASTConfig`]):
447
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
448
+ load the weights associated with the model, only the configuration. Check out the
449
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
450
+ """
451
+
452
+ AUDIO_SPECTROGRAM_TRANSFORMER_INPUTS_DOCSTRING = r"""
453
+ Args:
454
+ input_values (`torch.FloatTensor` of shape `(batch_size, max_length, num_mel_bins)`):
455
+ Float values of mel features extracted from the raw audio waveform. The raw audio waveform can be obtained by
456
+ loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a `numpy.ndarray`, *e.g.* via
457
+ the soundfile library (`pip install soundfile`). To prepare the array into `input_features`, the
458
+ [`AutoFeatureExtractor`] should be used for extracting the mel features, padding and conversion into a
459
+ tensor of type `torch.FloatTensor`. See [`~ASTFeatureExtractor.__call__`]
460
+
461
+ head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
462
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
463
+
464
+ - 1 indicates the head is **not masked**,
465
+ - 0 indicates the head is **masked**.
466
+
467
+ output_attentions (`bool`, *optional*):
468
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
469
+ tensors for more detail.
470
+ output_hidden_states (`bool`, *optional*):
471
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
472
+ more detail.
473
+ return_dict (`bool`, *optional*):
474
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
475
+ """
476
+
477
+
478
+ @add_start_docstrings(
479
+ "The bare AST Model transformer outputting raw hidden-states without any specific head on top.",
480
+ AUDIO_SPECTROGRAM_TRANSFORMER_START_DOCSTRING,
481
+ )
482
+ class ASTModel(ASTPreTrainedModel):
483
+ def __init__(self, config: ASTConfig) -> None:
484
+ super().__init__(config)
485
+ self.config = config
486
+
487
+ self.embeddings = ASTEmbeddings(config)
488
+ self.encoder = ASTEncoder(config)
489
+
490
+ self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
491
+
492
+ # Initialize weights and apply final processing
493
+ self.post_init()
494
+
495
+ def get_input_embeddings(self) -> ASTPatchEmbeddings:
496
+ return self.embeddings.patch_embeddings
497
+
498
+ def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None:
499
+ """
500
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
501
+ class PreTrainedModel
502
+ """
503
+ for layer, heads in heads_to_prune.items():
504
+ self.encoder.layer[layer].attention.prune_heads(heads)
505
+
506
+ @add_start_docstrings_to_model_forward(AUDIO_SPECTROGRAM_TRANSFORMER_INPUTS_DOCSTRING)
507
+ @add_code_sample_docstrings(
508
+ checkpoint=_CHECKPOINT_FOR_DOC,
509
+ output_type=BaseModelOutputWithPooling,
510
+ config_class=_CONFIG_FOR_DOC,
511
+ modality="audio",
512
+ expected_output=_EXPECTED_OUTPUT_SHAPE,
513
+ )
514
+ def forward(
515
+ self,
516
+ input_values: Optional[torch.Tensor] = None,
517
+ head_mask: Optional[torch.Tensor] = None,
518
+ output_attentions: Optional[bool] = None,
519
+ output_hidden_states: Optional[bool] = None,
520
+ return_dict: Optional[bool] = None,
521
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
522
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
523
+ output_hidden_states = (
524
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
525
+ )
526
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
527
+
528
+ if input_values is None:
529
+ raise ValueError("You have to specify input_values")
530
+
531
+ # Prepare head mask if needed
532
+ # 1.0 in head_mask indicate we keep the head
533
+ # attention_probs has shape bsz x n_heads x N x N
534
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
535
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
536
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
537
+
538
+ embedding_output = self.embeddings(input_values)
539
+
540
+ encoder_outputs = self.encoder(
541
+ embedding_output,
542
+ head_mask=head_mask,
543
+ output_attentions=output_attentions,
544
+ output_hidden_states=output_hidden_states,
545
+ return_dict=return_dict,
546
+ )
547
+ sequence_output = encoder_outputs[0]
548
+ sequence_output = self.layernorm(sequence_output)
549
+
550
+ pooled_output = (sequence_output[:, 0] + sequence_output[:, 1]) / 2
551
+
552
+ if not return_dict:
553
+ return (sequence_output, pooled_output) + encoder_outputs[1:]
554
+
555
+ return BaseModelOutputWithPooling(
556
+ last_hidden_state=sequence_output,
557
+ pooler_output=pooled_output,
558
+ hidden_states=encoder_outputs.hidden_states,
559
+ attentions=encoder_outputs.attentions,
560
+ )
561
+
562
+
563
+ class ASTMLPHead(nn.Module):
564
+ def __init__(self, config: ASTConfig):
565
+ super().__init__()
566
+ self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
567
+ self.dense = nn.Linear(config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity()
568
+
569
+ def forward(self, hidden_state):
570
+ hidden_state = self.layernorm(hidden_state)
571
+ hidden_state = self.dense(hidden_state)
572
+ return hidden_state
573
+
574
+
575
+ @add_start_docstrings(
576
+ """
577
+ Audio Spectrogram Transformer model with an audio classification head on top (a linear layer on top of the pooled
578
+ output) e.g. for datasets like AudioSet, Speech Commands v2.
579
+ """,
580
+ AUDIO_SPECTROGRAM_TRANSFORMER_START_DOCSTRING,
581
+ )
582
+ class ASTForAudioClassification(ASTPreTrainedModel):
583
+ def __init__(self, config: ASTConfig) -> None:
584
+ super().__init__(config)
585
+
586
+ self.num_labels = config.num_labels
587
+ self.audio_spectrogram_transformer = ASTModel(config)
588
+
589
+ # Classifier head
590
+ self.classifier = ASTMLPHead(config)
591
+
592
+ # Initialize weights and apply final processing
593
+ self.post_init()
594
+
595
+ @add_start_docstrings_to_model_forward(AUDIO_SPECTROGRAM_TRANSFORMER_INPUTS_DOCSTRING)
596
+ @add_code_sample_docstrings(
597
+ checkpoint=_SEQ_CLASS_CHECKPOINT,
598
+ output_type=SequenceClassifierOutput,
599
+ config_class=_CONFIG_FOR_DOC,
600
+ modality="audio",
601
+ expected_output=_SEQ_CLASS_EXPECTED_OUTPUT,
602
+ expected_loss=_SEQ_CLASS_EXPECTED_LOSS,
603
+ )
604
+ def forward(
605
+ self,
606
+ input_values: Optional[torch.Tensor] = None,
607
+ head_mask: Optional[torch.Tensor] = None,
608
+ labels: Optional[torch.Tensor] = None,
609
+ output_attentions: Optional[bool] = None,
610
+ output_hidden_states: Optional[bool] = None,
611
+ return_dict: Optional[bool] = None,
612
+ ) -> Union[tuple, SequenceClassifierOutput]:
613
+ r"""
614
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
615
+ Labels for computing the audio classification/regression loss. Indices should be in `[0, ...,
616
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
617
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
618
+ """
619
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
620
+
621
+ outputs = self.audio_spectrogram_transformer(
622
+ input_values,
623
+ head_mask=head_mask,
624
+ output_attentions=output_attentions,
625
+ output_hidden_states=output_hidden_states,
626
+ return_dict=return_dict,
627
+ )
628
+
629
+ pooled_output = outputs[1]
630
+ logits = self.classifier(pooled_output)
631
+
632
+ loss = None
633
+ if labels is not None:
634
+ if self.config.problem_type is None:
635
+ if self.num_labels == 1:
636
+ self.config.problem_type = "regression"
637
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
638
+ self.config.problem_type = "single_label_classification"
639
+ else:
640
+ self.config.problem_type = "multi_label_classification"
641
+
642
+ if self.config.problem_type == "regression":
643
+ loss_fct = MSELoss()
644
+ if self.num_labels == 1:
645
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
646
+ else:
647
+ loss = loss_fct(logits, labels)
648
+ elif self.config.problem_type == "single_label_classification":
649
+ loss_fct = CrossEntropyLoss()
650
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
651
+ elif self.config.problem_type == "multi_label_classification":
652
+ loss_fct = BCEWithLogitsLoss()
653
+ loss = loss_fct(logits, labels)
654
+
655
+ if not return_dict:
656
+ output = (logits,) + outputs[2:]
657
+ return ((loss,) + output) if loss is not None else output
658
+
659
+ return SequenceClassifierOutput(
660
+ loss=loss,
661
+ logits=logits,
662
+ hidden_states=outputs.hidden_states,
663
+ attentions=outputs.attentions,
664
+ )
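
For reference, a minimal usage sketch of the checkpoint referenced in the docstring constants above. This is not part of the committed file and rests on a few assumptions: it uses the stock transformers classes of the same names (ASTFeatureExtractor, ASTForAudioClassification) rather than this custom module, it assumes the checkpoint can be downloaded from the Hub, and the silent waveform is only a placeholder for real audio. The shape arithmetic in the comments follows get_shape above with the default AST configuration.

import numpy as np
import torch
from transformers import ASTFeatureExtractor, ASTForAudioClassification

checkpoint = "MIT/ast-finetuned-audioset-10-10-0.4593"  # checkpoint named in the docstring constants above
feature_extractor = ASTFeatureExtractor.from_pretrained(checkpoint)
model = ASTForAudioClassification.from_pretrained(checkpoint)
model.eval()

# One second of silence at 16 kHz stands in for a waveform loaded with soundfile.
waveform = np.zeros(16000, dtype=np.float32)
inputs = feature_extractor(waveform, sampling_rate=16000, return_tensors="pt")

# inputs["input_values"] has shape (batch, max_length=1024, num_mel_bins=128). With the
# default 16x16 patches and strides of 10, get_shape gives (128 - 16) // 10 + 1 = 12
# frequency patches and (1024 - 16) // 10 + 1 = 101 time patches, so the encoder sees
# 12 * 101 + 2 = 1214 tokens, matching _EXPECTED_OUTPUT_SHAPE = [1, 1214, 768].
with torch.no_grad():
    logits = model(**inputs).logits

predicted_id = int(logits.argmax(-1))
print(model.config.id2label[predicted_id])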