yangwang825 committed (verified)
Commit 214b742 · 1 parent: 310fbde

Upload model

Files changed (3)
  1. config.json +3 -2
  2. model.safetensors +3 -0
  3. modeling_whisper_spkreg.py +640 -0
config.json CHANGED
@@ -4,11 +4,12 @@
   "activation_function": "gelu",
   "apply_spec_augment": false,
   "architectures": [
-    "WhisperForConditionalGeneration"
+    "WhisperSpkRegModel"
   ],
   "attention_dropout": 0.0,
   "auto_map": {
-    "AutoConfig": "configuration_whisper_spkreg.WhisperSpkRegConfig"
+    "AutoConfig": "configuration_whisper_spkreg.WhisperSpkRegConfig",
+    "AutoModel": "modeling_whisper_spkreg.WhisperSpkRegModel"
   },
   "begin_suppress_tokens": [
     220,
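
The `auto_map` entries above are what let `transformers` resolve the custom classes shipped in this repository. As a rough sketch (not part of this commit; the repo id below is a placeholder for wherever these files live), loading would look like:

```python
# Hypothetical usage sketch: load the custom architecture registered via auto_map.
# "yangwang825/whisper-spkreg" is a placeholder repo id, not taken from this commit.
from transformers import AutoConfig, AutoModel

repo_id = "yangwang825/whisper-spkreg"

config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)  # -> WhisperSpkRegConfig
model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)    # -> WhisperSpkRegModel
print(type(config).__name__, type(model).__name__)
```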
model.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:93063179f1bbb1d278906e4a0c4adb3615ffcbe872ae23166d0fd9b0611ea1df
+size 290402464
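
The Git LFS pointer above records the blob's SHA-256 and byte size rather than the weights themselves. A hedged sketch (not part of this commit; the local path is a placeholder) for checking a downloaded `model.safetensors` against those values:

```python
# Hypothetical check: verify a downloaded model.safetensors against the
# oid/size recorded in the LFS pointer above.
import hashlib
import os

path = "model.safetensors"  # placeholder local path
expected_oid = "93063179f1bbb1d278906e4a0c4adb3615ffcbe872ae23166d0fd9b0611ea1df"
expected_size = 290402464

sha256 = hashlib.sha256()
with open(path, "rb") as f:
    for chunk in iter(lambda: f.read(1 << 20), b""):
        sha256.update(chunk)

assert os.path.getsize(path) == expected_size, "size mismatch"
assert sha256.hexdigest() == expected_oid, "sha256 mismatch"
```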
modeling_whisper_spkreg.py ADDED
@@ -0,0 +1,640 @@
+import math
+import warnings
+from typing import Union, Tuple, Optional
+
+import numpy as np
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from transformers.modeling_utils import PreTrainedModel
+from transformers.modeling_outputs import (
+    SequenceClassifierOutput,
+    Wav2Vec2BaseModelOutput,
+    Seq2SeqModelOutput,
+    BaseModelOutput
+)
+from transformers.cache_utils import (
+    Cache,
+    DynamicCache,
+    EncoderDecoderCache,
+    StaticCache
+)
+from transformers.models.whisper.modeling_whisper import (
+    WhisperEncoder,
+    WhisperEncoderLayer,
+    WhisperDecoderLayer,
+    WhisperDecoder,
+    _HIDDEN_STATES_START_POSITION
+)
+
+from .configuration_whisper_spkreg import WhisperSpkRegConfig
+
+
+def sinusoids(length: int, channels: int, max_timescale: float = 10000) -> torch.Tensor:
+    """Returns sinusoids for positional embedding"""
+    if channels % 2 != 0:
+        raise ValueError(
+            f"Number of channels has to be divisible by 2 for sinusoidal positional embeddings, got {channels} channels."
+        )
+    log_timescale_increment = math.log(max_timescale) / (channels // 2 - 1)
+    inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
+    scaled_time = torch.arange(length).view(-1, 1) * inv_timescales.view(1, -1)
+    return torch.cat([scaled_time.sin(), scaled_time.cos()], dim=1)
+
+
+def _compute_mask_indices(
+    shape: Tuple[int, int],
+    mask_prob: float,
+    mask_length: int,
+    attention_mask: Optional[torch.LongTensor] = None,
+    min_masks: int = 0,
+) -> np.ndarray:
+    """
+    Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for
+    ASR](https://arxiv.org/abs/1904.08779). Note that this method is not optimized to run on TPU and should be run on
+    CPU as part of the preprocessing during training.
+
+    Args:
+        shape: The shape for which to compute masks. This should be of a tuple of size 2 where
+            the first element is the batch size and the second element is the length of the axis to span.
+        mask_prob: The percentage of the whole axis (between 0 and 1) which will be masked. The number of
+            independently generated mask spans of length `mask_length` is computed by
+            `mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the
+            actual percentage will be smaller.
+        mask_length: size of the mask
+        min_masks: minimum number of masked spans
+        attention_mask: A (right-padded) attention mask which independently shortens the feature axis of
+            each batch dimension.
+    """
+    batch_size, sequence_length = shape
+
+    if mask_length < 1:
+        raise ValueError("`mask_length` has to be bigger than 0.")
+
+    if mask_length > sequence_length:
+        raise ValueError(
+            f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}"
+            f" and `sequence_length`: {sequence_length}`"
+        )
+
+    # epsilon is used for probabilistic rounding
+    epsilon = np.random.rand(1).item()
+
+    def compute_num_masked_span(input_length):
+        """Given input length, compute how many spans should be masked"""
+        num_masked_span = int(mask_prob * input_length / mask_length + epsilon)
+        num_masked_span = max(num_masked_span, min_masks)
+
+        # make sure num masked span <= sequence_length
+        if num_masked_span * mask_length > sequence_length:
+            num_masked_span = sequence_length // mask_length
+
+        # make sure num_masked span is also <= input_length - (mask_length - 1)
+        if input_length - (mask_length - 1) < num_masked_span:
+            num_masked_span = max(input_length - (mask_length - 1), 0)
+
+        return num_masked_span
+
+    # compute number of masked spans in batch
+    input_lengths = (
+        attention_mask.sum(-1).detach().tolist()
+        if attention_mask is not None
+        else [sequence_length for _ in range(batch_size)]
+    )
+
+    # SpecAugment mask to fill
+    spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool)
+    spec_aug_mask_idxs = []
+
+    max_num_masked_span = compute_num_masked_span(sequence_length)
+
+    if max_num_masked_span == 0:
+        return spec_aug_mask
+
+    for input_length in input_lengths:
+        # compute num of masked spans for this input
+        num_masked_span = compute_num_masked_span(input_length)
+
+        # get random indices to mask
+        spec_aug_mask_idx = np.random.choice(
+            np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False
+        )
+
+        # pick first sampled index that will serve as a dummy index to pad vector
+        # to ensure same dimension for all batches due to probabilistic rounding
+        # Picking first sample just pads those vectors twice.
+        if len(spec_aug_mask_idx) == 0:
+            # this case can only happen if `input_length` is strictly smaller than
+            # `sequence_length` in which case the last token has to be a padding
+            # token which we can use as a dummy mask id
+            dummy_mask_idx = sequence_length - 1
+        else:
+            dummy_mask_idx = spec_aug_mask_idx[0]
+
+        spec_aug_mask_idx = np.concatenate(
+            [spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx]
+        )
+        spec_aug_mask_idxs.append(spec_aug_mask_idx)
+
+    spec_aug_mask_idxs = np.array(spec_aug_mask_idxs)
+
+    # expand masked indices to masked spans
+    spec_aug_mask_idxs = np.broadcast_to(
+        spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length)
+    )
+    spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length)
+
+    # add offset to the starting indexes so that indexes now create a span
+    offsets = np.arange(mask_length)[None, None, :]
+    offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape(
+        batch_size, max_num_masked_span * mask_length
+    )
+    spec_aug_mask_idxs = spec_aug_mask_idxs + offsets
+
+    # ensure that we cannot have indices larger than sequence_length
+    if spec_aug_mask_idxs.max() > sequence_length - 1:
+        spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1
+
+    # scatter indices to mask
+    np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1)
+
+    return spec_aug_mask
+
+
+class WhisperSpkRegPreTrainedModel(PreTrainedModel):
+
+    config_class = WhisperSpkRegConfig
+    base_model_prefix = "model"
+    main_input_name = "input_features"
+    supports_gradient_checkpointing = True
+    _no_split_modules = ["WhisperEncoderLayer", "WhisperDecoderLayer"]
+    _supports_flash_attn_2 = True
+    _supports_sdpa = True
+    _supports_cache_class = True
+    _supports_static_cache = True
+
+    def _init_weights(self, module):
+        std = self.config.init_std
+        if isinstance(module, (nn.Linear, nn.Conv1d)):
+            module.weight.data.normal_(mean=0.0, std=std)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, nn.Embedding):
+            module.weight.data.normal_(mean=0.0, std=std)
+            if module.padding_idx is not None:
+                module.weight.data[module.padding_idx].zero_()
+        elif isinstance(module, WhisperEncoder):
+            with torch.no_grad():
+                embed_positions = module.embed_positions.weight
+                embed_positions.copy_(sinusoids(*embed_positions.shape))
+
+    def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor):
+        """
+        Computes the output length of the convolutional layers
+        """
+        input_lengths = (input_lengths - 1) // 2 + 1
+
+        return input_lengths
+
+
+class WhisperSpkRegModel(WhisperSpkRegPreTrainedModel):
+
+    def __init__(self, config: WhisperSpkRegConfig):
+        super().__init__(config)
+
+        self.encoder = WhisperEncoder(config)
+        self.decoder = WhisperDecoder(config)
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.decoder.embed_tokens
+
+    def set_input_embeddings(self, value):
+        self.decoder.embed_tokens = value
+
+    def get_encoder(self):
+        return self.encoder
+
+    def get_decoder(self):
+        return self.decoder
+
+    def freeze_encoder(self):
+        """
+        Calling this function will disable the gradient computation for the Whisper encoder so that its parameters will
+        not be updated during training.
+        """
+        self.encoder._freeze_parameters()
+
+    def _mask_input_features(
+        self,
+        input_features: torch.FloatTensor,
+        attention_mask: Optional[torch.LongTensor] = None,
+    ):
+        """
+        Masks extracted features along time axis and/or along feature axis according to
+        [SpecAugment](https://arxiv.org/abs/1904.08779).
+        """
+
+        # `config.apply_spec_augment` can set masking to False
+        if not getattr(self.config, "apply_spec_augment", True):
+            return input_features
+
+        # generate indices & apply SpecAugment along time axis
+        batch_size, hidden_size, sequence_length = input_features.size()
+
+        if self.config.mask_time_prob > 0 and self.training:
+            # generate indices & apply SpecAugment along time axis
+            mask_time_indices = _compute_mask_indices(
+                (batch_size, sequence_length),
+                mask_prob=self.config.mask_time_prob,
+                mask_length=self.config.mask_time_length,
+                attention_mask=attention_mask,
+                min_masks=self.config.mask_time_min_masks,
+            )
+            mask_time_indices = torch.tensor(mask_time_indices, device=input_features.device, dtype=torch.bool)
+            mask_time_indices = mask_time_indices[:, None].expand(-1, hidden_size, -1)
+            input_features[mask_time_indices] = 0
+
+        if self.config.mask_feature_prob > 0 and self.training:
+            # generate indices & apply SpecAugment along feature axis
+            mask_feature_indices = _compute_mask_indices(
+                (batch_size, hidden_size),
+                mask_prob=self.config.mask_feature_prob,
+                mask_length=self.config.mask_feature_length,
+                min_masks=self.config.mask_feature_min_masks,
+            )
+            mask_feature_indices = torch.tensor(mask_feature_indices, device=input_features.device, dtype=torch.bool)
+            input_features[mask_feature_indices] = 0
+
+        return input_features
+
+    def forward(
+        self,
+        input_features: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.LongTensor] = None,
+        decoder_input_ids: Optional[torch.LongTensor] = None,
+        decoder_attention_mask: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        decoder_head_mask: Optional[torch.Tensor] = None,
+        cross_attn_head_mask: Optional[torch.Tensor] = None,
+        encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        past_key_values: Optional[Union[EncoderDecoderCache, Tuple[torch.FloatTensor]]] = None,
+        decoder_inputs_embeds: Optional[Tuple[torch.FloatTensor]] = None,
+        decoder_position_ids: Optional[Tuple[torch.LongTensor]] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+    ) -> Union[Tuple[torch.Tensor], Seq2SeqModelOutput]:
+        r"""
+        Returns:
+
+        Example:
+         ```python
+         >>> import torch
+         >>> from transformers import AutoFeatureExtractor, WhisperModel
+         >>> from datasets import load_dataset
+
+         >>> model = WhisperModel.from_pretrained("openai/whisper-base")
+         >>> feature_extractor = AutoFeatureExtractor.from_pretrained("openai/whisper-base")
+         >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
+         >>> inputs = feature_extractor(ds[0]["audio"]["array"], return_tensors="pt")
+         >>> input_features = inputs.input_features
+         >>> decoder_input_ids = torch.tensor([[1, 1]]) * model.config.decoder_start_token_id
+         >>> last_hidden_state = model(input_features, decoder_input_ids=decoder_input_ids).last_hidden_state
+         >>> list(last_hidden_state.shape)
+         [1, 2, 512]
+         ```"""
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        if encoder_outputs is None:
+            input_features = self._mask_input_features(input_features, attention_mask=attention_mask)
+
+            encoder_outputs = self.encoder(
+                input_features,
+                head_mask=head_mask,
+                output_attentions=output_attentions,
+                output_hidden_states=output_hidden_states,
+                return_dict=return_dict,
+            )
+        # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
+        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
+            encoder_outputs = BaseModelOutput(
+                last_hidden_state=encoder_outputs[0],
+                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
+                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
+            )
+
+        # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn)
+        decoder_outputs = self.decoder(
+            input_ids=decoder_input_ids,
+            attention_mask=decoder_attention_mask,
+            encoder_hidden_states=encoder_outputs[0],
+            head_mask=decoder_head_mask,
+            cross_attn_head_mask=cross_attn_head_mask,
+            past_key_values=past_key_values,
+            inputs_embeds=decoder_inputs_embeds,
+            position_ids=decoder_position_ids,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            cache_position=cache_position,
+        )
+
+        if not return_dict:
+            return decoder_outputs + encoder_outputs
+
+        return Seq2SeqModelOutput(
+            last_hidden_state=decoder_outputs.last_hidden_state,
+            past_key_values=decoder_outputs.past_key_values,
+            decoder_hidden_states=decoder_outputs.hidden_states,
+            decoder_attentions=decoder_outputs.attentions,
+            cross_attentions=decoder_outputs.cross_attentions,
+            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
+            encoder_hidden_states=encoder_outputs.hidden_states,
+            encoder_attentions=encoder_outputs.attentions,
+        )
+
+
+class AngularLinear(nn.Module):
+
+    def __init__(self, in_features: int, out_features: int):
+        super(AngularLinear, self).__init__()
+        self.in_features = in_features
+        self.out_features = out_features
+        self.weight = torch.nn.Parameter(
+            torch.FloatTensor(out_features, in_features), requires_grad=True
+        )
+        nn.init.xavier_normal_(self.weight, gain=1)
+
+    def forward(
+        self,
+        inputs: torch.Tensor,
+    ):
+        # Calculation of cos(theta)
+        cosine = F.linear(F.normalize(inputs), F.normalize(self.weight))
+        return cosine
+
+    def extra_repr(self) -> str:
+        return 'in_features={}, out_features={}'.format(
+            self.in_features, self.out_features
+        )
+
+
+class AMSoftmaxLoss(nn.Module):
+    """Additive Margin Softmax (CosFace).
+
+    Paper: Wang, Feng, et al. "Additive margin softmax for face verification."
+    IEEE Signal Processing Letters 25.7 (2018): 926-930.
+    """
+    def __init__(
+        self,
+        scale: float = 30.0,
+        margin: float = 0.35,
+        label_smoothing: float = 0.0,
+        reduction: str = "mean"
+    ):
+        """
+        Args:
+            scale: Scaling factor for logits (default: 30.0)
+            margin: Additive cosine margin (default: 0.35)
+            label_smoothing: Label smoothing factor (default: 0.0)
+            reduction: Reduction method (default: "mean")
+        """
+        super(AMSoftmaxLoss, self).__init__()
+        self.scale = scale
+        self.margin = margin
+        self.label_smoothing = label_smoothing
+        self.reduction = reduction
+
+    def forward(
+        self,
+        inputs: torch.Tensor,
+        targets: torch.Tensor,
+    ):
+        """
+        Args:
+            inputs: Cosine similarities of shape (batch_size, num_labels)
+            targets: Ground truth labels of shape (batch_size)
+        Returns:
+            Loss value
+        """
+        _, num_labels = inputs.shape
+        # `inputs` are the outputs from AngularLinear()
+        cos_theta = torch.clamp(inputs, -1.0 + 1e-7, 1.0 - 1e-7)
+        psi = cos_theta - self.margin
+        one_hot = nn.functional.one_hot(targets, num_labels)
+        outputs = self.scale * torch.where(one_hot.bool(), psi, cos_theta)
+        loss = F.cross_entropy(
+            outputs, targets, label_smoothing=self.label_smoothing, reduction=self.reduction
+        )
+        return loss
+
+
+class AAMSoftmaxLoss(nn.Module):
+    """Additive Angular Margin Softmax (ArcFace).
+
+    Paper: Deng, Jiankang, et al. "Arcface: Additive angular margin loss for deep face recognition."
+    Proceedings of the IEEE/CVF conference on computer vision and pattern recognition. 2019.
+    """
+    def __init__(
+        self,
+        scale: float = 30.0,
+        margin: float = 0.35,
+        easy_margin: bool = False,
+        label_smoothing: float = 0.0,
+        reduction: str = "mean"
+    ):
+        """
+        Args:
+            scale: Scaling factor for logits (default: 30.0)
+            margin: Angular margin (default: 0.35)
+            easy_margin: Use the easy margin loss (default: False)
+            label_smoothing: Label smoothing factor (default: 0.0)
+            reduction: Reduction method (default: "mean")
+        """
+        super(AAMSoftmaxLoss, self).__init__()
+        self.scale = scale
+        self.margin = margin
+        self.easy_margin = easy_margin
+        self.label_smoothing = label_smoothing
+        self.reduction = reduction
+
+    def forward(
+        self,
+        inputs: torch.Tensor,
+        targets: torch.Tensor,
+    ):
+        """
+        Args:
+            inputs: Cosine similarities of shape (batch_size, num_labels)
+            targets: Ground truth labels of shape (batch_size)
+        Returns:
+            Loss value
+        """
+        _, num_labels = inputs.shape
+        # `inputs` are the outputs from AngularLinear()
+        cos_theta = torch.clamp(inputs, -1.0 + 1e-7, 1.0 - 1e-7)
+        theta = torch.acos(cos_theta)
+        psi = torch.cos(theta + self.margin)
+        one_hot = nn.functional.one_hot(targets, num_labels)
+        outputs = self.scale * torch.where(one_hot.bool(), psi, cos_theta)
+        loss = F.cross_entropy(
+            outputs, targets, label_smoothing=self.label_smoothing, reduction=self.reduction
+        )
+        return loss
+
+
+class WhisperSpkRegForSequenceClassification(WhisperSpkRegPreTrainedModel):
+
+    def __init__(self, config):
+        super().__init__(config)
+
+        self.encoder = WhisperEncoder(config)
+        num_layers = config.num_hidden_layers + 1  # transformer layers + input embeddings
+        if config.use_weighted_layer_sum:
+            self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
+        self.projector = nn.Linear(config.hidden_size, config.classifier_proj_size)
+        self.classifier = nn.Linear(config.classifier_proj_size, config.num_labels)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def freeze_encoder(self):
+        """
+        Calling this function will disable the gradient computation for the Whisper encoder so that its parameters will
+        not be updated during training. Only the projection layers and classification head will be updated.
+        """
+        self.encoder._freeze_parameters()
+
+    def get_input_embeddings(self) -> nn.Module:
+        return self.encoder.get_input_embeddings()
+
+    def set_input_embeddings(self, value: nn.Module):
+        self.encoder.set_input_embeddings(value)
+
+    def forward(
+        self,
+        input_features: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        labels: Optional[torch.LongTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+            Labels for computing the sequence classification loss. Indices should be in `[0, ...,
+            config.num_labels - 1]`. The loss function is selected by `config.loss_fct`
+            (`"cross_entropy"`, `"additive_margin"`, or `"additive_angular_margin"`).
+
+        Returns:
+
+        Example:
+
+        ```python
+        >>> import torch
+        >>> from transformers import AutoFeatureExtractor, WhisperForAudioClassification
+        >>> from datasets import load_dataset
+
+        >>> feature_extractor = AutoFeatureExtractor.from_pretrained("sanchit-gandhi/whisper-medium-fleurs-lang-id")
+        >>> model = WhisperForAudioClassification.from_pretrained("sanchit-gandhi/whisper-medium-fleurs-lang-id")
+
+        >>> ds = load_dataset("google/fleurs", "all", split="validation", streaming=True)
+        >>> sample = next(iter(ds))
+
+        >>> inputs = feature_extractor(
+        ...     sample["audio"]["array"], sampling_rate=sample["audio"]["sampling_rate"], return_tensors="pt"
+        ... )
+        >>> input_features = inputs.input_features
+
+        >>> with torch.no_grad():
+        ...     logits = model(input_features).logits
+
+        >>> predicted_class_ids = torch.argmax(logits).item()
+        >>> predicted_label = model.config.id2label[predicted_class_ids]
+        >>> predicted_label
+        'Afrikaans'
+        ```"""
+
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        if self.config.use_weighted_layer_sum:
+            output_hidden_states = True
+        elif output_hidden_states is None:
+            output_hidden_states = self.config.output_hidden_states
+
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        if encoder_outputs is None:
+            encoder_outputs = self.encoder(
+                input_features,
+                head_mask=head_mask,
+                output_attentions=output_attentions,
+                output_hidden_states=output_hidden_states,
+                return_dict=return_dict,
+            )
+
+        if self.config.use_weighted_layer_sum:
+            hidden_states = encoder_outputs[_HIDDEN_STATES_START_POSITION]
+            hidden_states = torch.stack(hidden_states, dim=1)
+            norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
+            hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
+        else:
+            hidden_states = encoder_outputs[0]
+
+        hidden_states = self.projector(hidden_states)
+        pooled_output = hidden_states.mean(dim=1)
+
+        logits = self.classifier(pooled_output)
+
+        loss = None
+        if labels is not None:
+            if self.config.loss_fct == 'cross_entropy':
+                loss_fct = nn.CrossEntropyLoss(
+                    label_smoothing=self.config.label_smoothing,
+                    reduction=self.config.reduction
+                )
+            elif self.config.loss_fct == 'additive_margin':
+                loss_fct = AMSoftmaxLoss(
+                    scale=self.config.scale,
+                    margin=self.config.margin,
+                    label_smoothing=self.config.label_smoothing,
+                    reduction=self.config.reduction
+                )
+            elif self.config.loss_fct == 'additive_angular_margin':
+                loss_fct = AAMSoftmaxLoss(
+                    scale=self.config.scale,
+                    margin=self.config.margin,
+                    easy_margin=self.config.easy_margin,
+                    label_smoothing=self.config.label_smoothing,
+                    reduction=self.config.reduction
+                )
+            loss = loss_fct(
+                logits.view(-1, self.config.num_labels),
+                labels.view(-1).to(logits.device),
+            )
+
+        if not return_dict:
+            output = (logits,) + encoder_outputs[1:]
+            return ((loss,) + output) if loss is not None else output
+
+        return SequenceClassifierOutput(
+            loss=loss,
+            logits=logits,
+            hidden_states=encoder_outputs.hidden_states,
+            attentions=encoder_outputs.attentions,
+        )
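
For orientation only (not part of this commit): a minimal sketch of how `AngularLinear` and the two margin-based losses defined above fit together. It assumes `modeling_whisper_spkreg.py` and `configuration_whisper_spkreg.py` are importable as a package (the `whisper_spkreg` package name below is hypothetical), and the shapes and speaker counts are illustrative.

```python
# Minimal sketch: cosine logits from AngularLinear fed to the CosFace- and
# ArcFace-style losses above. Package name, dimensions, and labels are made up.
import torch
from whisper_spkreg.modeling_whisper_spkreg import (  # hypothetical package layout
    AngularLinear,
    AMSoftmaxLoss,
    AAMSoftmaxLoss,
)

embeddings = torch.randn(4, 256)      # pooled utterance embeddings (batch, dim)
labels = torch.tensor([0, 2, 1, 3])   # speaker ids

head = AngularLinear(in_features=256, out_features=10)  # 10 training speakers
cos_logits = head(embeddings)                           # cosine similarities in [-1, 1]

am_loss = AMSoftmaxLoss(scale=30.0, margin=0.35)(cos_logits, labels)    # additive cosine margin
aam_loss = AAMSoftmaxLoss(scale=30.0, margin=0.35)(cos_logits, labels)  # additive angular margin
print(am_loss.item(), aam_loss.item())
```

In `WhisperSpkRegForSequenceClassification.forward`, the same losses are instantiated from `config.loss_fct`, `config.scale`, `config.margin`, `config.label_smoothing`, and `config.reduction` and applied to the classifier logits.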