jimbozhang committed
Commit: 2e9a543
Parent: 9b348cb

Update 3 files

ced_model/configuration_ced.py CHANGED

@@ -123,15 +123,12 @@ class CedConfig(PretrainedConfig):
         self.qkv_bias = qkv_bias
         self.target_length = target_length
         self.win_size = kwargs.get("win_size", 512)
+        self.loss = "BCE"
 
         if self.outputdim == 527:
-            with open(
-                cached_file("topel/ConvNeXt-Tiny-AT", "class_labels_indices.csv"), "r"
-            ) as f:
+            with open(cached_file("topel/ConvNeXt-Tiny-AT", "class_labels_indices.csv"), "r") as f:
                 self.id2label = {
-                    int(line.split(",", maxsplit=3)[0]): line.split(",", maxsplit=3)[2]
-                    .replace('"', "")
-                    .strip("\n")
+                    int(line.split(",", maxsplit=3)[0]): line.split(",", maxsplit=3)[2].replace('"', "").strip("\n")
                     for line in f.readlines()[1:]
                 }
                 self.label2id = {v: k for k, v in self.id2label.items()}
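For reference, a minimal sketch of what the reflowed comprehension does to one row of class_labels_indices.csv, assuming the AudioSet-style layout (index,mid,display_name) used by the referenced topel/ConvNeXt-Tiny-AT repository; the sample row is illustrative:

    # Sample row (after the header line, which readlines()[1:] skips):
    line = '0,/m/09x0r,"Speech"\n'
    parts = line.split(",", maxsplit=3)            # ['0', '/m/09x0r', '"Speech"\n']
    idx = int(parts[0])                            # 0
    label = parts[2].replace('"', "").strip("\n")  # 'Speech'

Note that display names containing commas (e.g. "Male speech, man speaking") split at the inner comma, so only the text before it survives; that behavior is unchanged by this commit. The new self.loss = "BCE" default pairs with the loss switch added to CedForAudioClassification.forward below.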
ced_model/feature_extraction_ced.py CHANGED

@@ -16,7 +16,7 @@
     Feature extractor class for CED.
     """
 
-from typing import Optional, Union
+from typing import List, Optional, Union
 
 import numpy as np
 import torch
@@ -77,10 +77,14 @@ class CedFeatureExtractor(SequenceFeatureExtractor):
         self.f_max = f_max
         self.hop_size = hop_size
 
+        self.model_input_names = ["input_values"]
+
     def __call__(
         self,
-        x: Union[np.ndarray, torch.Tensor],
+        x: Union[np.ndarray, torch.Tensor, List[np.ndarray], List[torch.Tensor]],
         sampling_rate: Optional[int] = None,
+        max_length: Optional[int] = 16000,
+        truncation: bool = True,
         return_tensors="pt",
     ) -> BatchFeature:
         r"""
@@ -88,6 +92,14 @@ class CedFeatureExtractor(SequenceFeatureExtractor):
 
         Args:
             x: Input audio signal tensor.
+            sampling_rate (int, *optional*, defaults to `None`):
+                Sampling rate of the input audio signal.
+            max_length (int, *optional*, defaults to 16000):
+                Maximum length of the input audio signal.
+            truncation (bool, *optional*, defaults to `True`):
+                Whether to truncate the input signal to max_length.
+            return_tensors (str, *optional*, defaults to "pt"):
+                If set to "pt", the return type will be a PyTorch tensor.
 
         Returns:
             BatchFeature: A dictionary containing the extracted features.
@@ -96,9 +108,7 @@ class CedFeatureExtractor(SequenceFeatureExtractor):
             sampling_rate = self.sampling_rate
 
         if return_tensors != "pt":
-            raise NotImplementedError(
-                "Only return_tensors='pt' is currently supported."
-            )
+            raise NotImplementedError("Only return_tensors='pt' is currently supported.")
 
         mel_spectrogram = audio_transforms.MelSpectrogram(
             f_min=self.f_min,
@@ -112,10 +122,42 @@ class CedFeatureExtractor(SequenceFeatureExtractor):
         )
         amplitude_to_db = audio_transforms.AmplitudeToDB(top_db=120)
 
-        x = torch.from_numpy(x).float() if isinstance(x, np.ndarray) else x.float()
-        if x.dim() == 1:
-            x = x.unsqueeze(0)
+        if isinstance(x, np.ndarray):
+            if x.ndim == 1:
+                x = x[np.newaxis, :]
+            if x.ndim != 2:
+                raise ValueError("np.ndarray input must be 1D or 2D.")
+            x = torch.from_numpy(x)
+        elif isinstance(x, torch.Tensor):
+            if x.dim() == 1:
+                x = x.unsqueeze(0)
+            if x.dim() != 2:
+                raise ValueError("torch.Tensor input must be 1D or 2D.")
+        elif isinstance(x, (list, tuple)):
+            longest_length = max(x_.shape[0] for x_ in x)
+            if not truncation and max_length < longest_length:
+                max_length = longest_length
+
+            if all(isinstance(x_, np.ndarray) for x_ in x):
+                if not all(x_.ndim == 1 for x_ in x):
+                    raise ValueError("All np.ndarray in a list must be 1D.")
+
+                x_trim = [x_[:max_length] for x_ in x]
+                x_pad = [np.pad(x_, (0, max_length - x_.shape[0]), mode="constant", constant_values=0) for x_ in x_trim]
+                x = torch.stack([torch.from_numpy(x_) for x_ in x_pad])
+            elif all(isinstance(x_, torch.Tensor) for x_ in x):
+                if not all(x_.dim() == 1 for x_ in x):
+                    raise ValueError("All torch.Tensor in a list must be 1D.")
+                x_pad = [torch.nn.functional.pad(x_, (0, max_length - x_.shape[0]), value=0) for x_ in x]
+                x = torch.stack(x_pad)
+            else:
+                raise ValueError("Input list must be numpy arrays or PyTorch tensors.")
+        else:
+            raise ValueError(
+                "Input must be a numpy array, a list of numpy arrays, a PyTorch tensor, or a list of PyTorch tensors."
+            )
 
+        x = x.float()
         x = mel_spectrogram(x)
         x = amplitude_to_db(x)
         return BatchFeature({"input_values": x})
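A usage sketch of the extended __call__ with unequal-length clips; the checkpoint id is hypothetical, and the exact output shape depends on the configured mel parameters:

    import numpy as np
    from ced_model.feature_extraction_ced import CedFeatureExtractor

    extractor = CedFeatureExtractor.from_pretrained("mispeech/ced-tiny")  # hypothetical id

    clips = [
        np.random.randn(16000).astype(np.float32),  # 1.0 s at 16 kHz
        np.random.randn(8000).astype(np.float32),   # 0.5 s, zero-padded up to max_length
    ]
    features = extractor(clips, sampling_rate=16000, max_length=16000)
    print(features["input_values"].shape)           # (2, n_mels, n_frames)

With truncation=True (the default) longer clips are cut to max_length; with truncation=False, max_length grows to the longest clip so nothing is lost. Stacking everything into one (batch, samples) tensor before the transforms lets torchaudio's MelSpectrogram process the whole batch in a single call.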
 
ced_model/modeling_ced.py CHANGED

@@ -106,9 +106,7 @@ class CedAudioPatchEmbed(nn.Module):
         self.num_patches = self.grid_size[0] * self.grid_size[1]
         self.flatten = flatten
 
-        self.proj = nn.Conv2d(
-            in_chans, embed_dim, kernel_size=patch_size, stride=patch_stride
-        )
+        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_stride)
         self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
 
     def forward(self, x):
@@ -143,11 +141,7 @@ class CedAttention(nn.Module):
 
     def forward(self, x):
         B, N, C = x.shape
-        qkv = (
-            self.qkv(x)
-            .reshape(B, N, 3, self.num_heads, C // self.num_heads)
-            .permute(2, 0, 3, 1, 4)
-        )
+        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
         q, k, v = qkv.unbind(0)  # make torchscript happy (cannot use tensor as tuple)
 
         attn = (q @ k.transpose(-2, -1)) * self.scale
@@ -221,9 +215,7 @@ class DropPath(nn.Module):
         return f"drop_prob={round(self.drop_prob,3):0.3f}"
 
 
-def drop_path(
-    x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True
-):
+def drop_path(x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True):
     """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
 
     This is the same as the DropConnect impl I (https://github.com/rwightman) created for EfficientNet, etc networks,
@@ -236,9 +228,7 @@ def drop_path(
    if drop_prob == 0.0 or not training:
        return x
    keep_prob = 1 - drop_prob
-    shape = (x.shape[0],) + (1,) * (
-        x.ndim - 1
-    )  # work with diff dim tensors, not just 2D ConvNets
+    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
    if keep_prob > 0.0 and scale_by_keep:
        random_tensor.div_(keep_prob)
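The reflowed drop_path body is behaviorally unchanged; as a quick sketch of what the per-sample mask shape does (illustrative values, not part of the commit):

    import torch

    x = torch.ones(4, 197, 192)                    # (batch, tokens, dim) residual branch
    keep_prob = 0.8
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)    # (4, 1, 1): one Bernoulli draw per sample
    mask = x.new_empty(shape).bernoulli_(keep_prob) / keep_prob
    out = x * mask                                 # each sample's branch is kept (rescaled) or zeroed

Dividing by keep_prob keeps the expected activation magnitude constant, which is what scale_by_keep toggles.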
@@ -373,17 +363,11 @@ class CedModel(CedPreTrainedModel):
             patch_stride=config.patch_stride,
         )
 
-        self.time_pos_embed = nn.Parameter(
-            torch.randn(1, config.embed_dim, 1, self.patch_embed.grid_size[1]) * 0.02
-        )
-        self.freq_pos_embed = nn.Parameter(
-            torch.randn(1, config.embed_dim, self.patch_embed.grid_size[0], 1) * 0.02
-        )
+        self.time_pos_embed = nn.Parameter(torch.randn(1, config.embed_dim, 1, self.patch_embed.grid_size[1]) * 0.02)
+        self.freq_pos_embed = nn.Parameter(torch.randn(1, config.embed_dim, self.patch_embed.grid_size[0], 1) * 0.02)
         norm_layer = partial(nn.LayerNorm, eps=1e-6)
         act_layer = nn.GELU
-        dpr = [
-            x.item() for x in torch.linspace(0, config.drop_path_rate, config.depth)
-        ]  # stochastic depth decay rule
+        dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, config.depth)]  # stochastic depth decay rule
         self.pos_drop = nn.Dropout(p=config.drop_rate)
         self.blocks = nn.Sequential(
             *[
@@ -407,13 +391,16 @@ class CedModel(CedPreTrainedModel):
         # Initialize weights and apply final processing
         self.post_init()
 
+    def _freeze_parameters(self):
+        for param in self.parameters():
+            param.requires_grad = False
+        self._requires_grad = False
+
     def forward_features(self, x: torch.Tensor) -> torch.Tensor:
         x = self.patch_embed(x)
         _, _, _, t = x.shape
         x = x + self.time_pos_embed[:, :, :, :t]
-        x = (
-            x + self.freq_pos_embed[:, :, :, :]
-        )  # Just to support __getitem__ in posembed
+        x = x + self.freq_pos_embed[:, :, :, :]  # Just to support __getitem__ in posembed
 
         # x = rearrange(x, 'b c f t -> b (f t) c')
         x = torch.permute(torch.flatten(x, 2, 3), (0, 2, 1))
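The two factorized positional embeddings reflowed above broadcast against the patch grid; a shape sketch with illustrative sizes (the real sizes come from patch_embed.grid_size):

    import torch

    b, d, f, t = 2, 192, 4, 62                      # batch, embed_dim, freq patches, time patches
    x = torch.randn(b, d, f, t)                     # patch_embed output
    time_pos_embed = torch.randn(1, d, 1, t) * 0.02
    freq_pos_embed = torch.randn(1, d, f, 1) * 0.02

    x = x + time_pos_embed[:, :, :, :t]             # broadcasts over batch and frequency
    x = x + freq_pos_embed                          # broadcasts over batch and time

Slicing time_pos_embed with :t is what lets clips shorter than the training length reuse a prefix of the embedding.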
@@ -442,9 +429,7 @@
 
         if splits[-1].shape[-1] < self.maximal_allowed_length:
             if self.config.pad_last:
-                pad = torch.zeros(
-                    *x.shape[:-1], self.maximal_allowed_length, device=x.device
-                )
+                pad = torch.zeros(*x.shape[:-1], self.maximal_allowed_length, device=x.device)
                 pad[..., : splits[-1].shape[-1]] = splits[-1]
                 splits = torch.stack((*splits[:-1], pad), dim=0)
             else:
@@ -497,9 +482,7 @@ class CedForAudioClassification(CedPreTrainedModel):
         elif self.config.pooling == "dm":
             # Unpack using the frequency dimension, which is constant
             # 'b (f t) d -> b f t d', f=self.patch_embed.grid_size[0])
-            x = torch.reshape(
-                x, (x.shape[0], self.patch_embed.grid_size[0], -1, x.shape[3])
-            )
+            x = torch.reshape(x, (x.shape[0], self.patch_embed.grid_size[0], -1, x.shape[3]))
 
             # First pool in frequency, then sigmoid the (B T D) output
             x = self.outputlayer(x.mean(1)).sigmoid()
@@ -507,9 +490,10 @@ class CedForAudioClassification(CedPreTrainedModel):
         else:
             return x.mean(1)
 
-    @add_start_docstrings_to_model_forward(
-        CED_INPUTS_DOCSTRING.format("batch_size, sequence_length")
-    )
+    def freeze_encoder(self):
+        self.encoder._freeze_parameters()
+
+    @add_start_docstrings_to_model_forward(CED_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
         checkpoint=_SEQ_CLASS_CHECKPOINT,
         output_type=SequenceClassifierOutput,
@@ -519,9 +503,7 @@ class CedForAudioClassification(CedPreTrainedModel):
         expected_output=_SEQ_CLASS_EXPECTED_OUTPUT,
         expected_loss=_SEQ_CLASS_EXPECTED_LOSS,
     )
-    def forward(
-        self, input_values: torch.Tensor, labels: Optional[torch.Tensor] = None
-    ):
+    def forward(self, input_values: torch.Tensor, labels: Optional[torch.Tensor] = None):
         """
         Runs a forward pass of the CED model for an audio classification task.
 
@@ -554,14 +536,15 @@ class CedForAudioClassification(CedPreTrainedModel):
         logits = self.forward_head(last_hidden_states)
 
         if labels is not None:
-            loss_fct = nn.BCEWithLogitsLoss()
-            labels = nn.functional.one_hot(
-                labels, num_classes=self.config.outputdim
-            ).float()
+            if self.config.loss == "CE":
+                loss_fct = nn.CrossEntropyLoss()
+            elif self.config.loss == "BCE":
+                loss_fct = nn.BCEWithLogitsLoss()
+            else:
+                raise NotImplementedError("Need to set 'CE' or 'BCE' as config.loss.")
+            labels = nn.functional.one_hot(labels, num_classes=self.config.outputdim).float()
             loss = loss_fct(logits, labels)
         else:
             loss = None
 
-        return SequenceClassifierOutput(
-            logits=logits, loss=loss, hidden_states=last_hidden_states
-        )
+        return SequenceClassifierOutput(logits=logits, loss=loss, hidden_states=last_hidden_states)
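Taken together, the new freeze_encoder hook and the config-driven loss might be used like this when fine-tuning (a sketch; the checkpoint id is hypothetical and mirrors the one used above):

    from ced_model.modeling_ced import CedForAudioClassification

    model = CedForAudioClassification.from_pretrained("mispeech/ced-tiny")  # hypothetical id
    model.config.loss = "CE"   # default is "BCE", set in CedConfig
    model.freeze_encoder()     # backbone weights stop receiving gradients

    # forward() one-hots integer labels to (batch, outputdim) floats before the loss

One caveat: with config.loss = "CE", nn.CrossEntropyLoss receives one-hot float targets, which PyTorch accepts as class probabilities only since 1.10; older versions expect integer class indices and this path will fail.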
 
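End to end, the three changed files compose as follows (reusing extractor, clips, and model from the sketches above; the label lookup assumes the 527-class AudioSet head, where config.id2label was populated in configuration_ced.py):

    import torch

    inputs = extractor(clips, sampling_rate=16000)   # BatchFeature with "input_values"
    with torch.no_grad():
        outputs = model(**inputs)                    # SequenceClassifierOutput

    for pred in outputs.logits.argmax(-1).tolist():
        print(model.config.id2label[pred])           # e.g. "Speech"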