HoneyTian committed · Commit cae1e58 · 1 Parent(s): 32bdb94
examples/vm_sound_classification/conv2d_classifier.yaml CHANGED
@@ -10,6 +10,11 @@ mel_spectrogram_param:
   window_fn: hamming
   n_mels: 80
 
+spec_augment_param:
+  aug_volume_factor_range:
+  - 0.5
+  - 2.0
+
 conv2d_block_param_list:
 - batch_norm: true
   in_channels: 1
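
For orientation, here is a minimal sketch (not part of the commit) of how the new spec_augment_param block could be read from this YAML and fed to the SpecAugment module introduced below; the yaml.safe_load call and the literal config path are assumptions for illustration only.

# Hedged sketch: build the augmenter from the new YAML block.
# The config path and the use of yaml.safe_load are assumptions, not from the commit.
import yaml

from toolbox.torchaudio.augment.spec_augment import SpecAugment

with open("examples/vm_sound_classification/conv2d_classifier.yaml") as f:
    config = yaml.safe_load(f)

spec_augment = SpecAugment(
    aug_volume_factor_range=tuple(config["spec_augment_param"]["aug_volume_factor_range"]),
)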
examples/vm_sound_classification/step_1_prepare_data.py CHANGED
@@ -70,7 +70,7 @@ def get_dataset(args):
         "mute": "mute",
         "noise": "noise",
         "noise_mute": "noise",
-        "voice": "voice_or_noise",
+        "voice": "voice",
         "voicemail": "voicemail",
     }
     # label8_map = {
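
The only change in this hunk is that raw "voice" samples now keep a dedicated "voice" label instead of being folded into a combined "voice_or_noise" class. A small illustrative sketch of applying such a map (the variable names and the loop are hypothetical, not taken from the commit):

# Illustrative only: applying the label map shown above to raw labels.
label_map = {
    "mute": "mute",
    "noise": "noise",
    "noise_mute": "noise",
    "voice": "voice",        # previously mapped to "voice_or_noise"
    "voicemail": "voicemail",
}

raw_labels = ["voice", "noise_mute", "voicemail"]
print([label_map[label] for label in raw_labels])
# ['voice', 'noise', 'voicemail']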
toolbox/torchaudio/augment/spec_augment.py ADDED
@@ -0,0 +1,44 @@
+#!/usr/bin/python3
+# -*- coding: utf-8 -*-
+"""
+https://github.com/wenet-e2e/wenet/blob/main/wenet/dataset/processor.py
+"""
+import random
+from typing import List, Tuple
+
+import torch
+import torch.nn as nn
+
+
+class SpecAugment(nn.Module):
+    def __init__(self,
+                 aug_volume_factor_range: Tuple[float, float] = (0.5, 2.0),
+                 ):
+        super().__init__()
+        self.aug_volume_factor_range = aug_volume_factor_range
+
+    @staticmethod
+    def augment_volume(spec: torch.Tensor, factor_range: Tuple[float, float] = (0.5, 2.0)):
+        factor = random.uniform(*factor_range)
+        spec_ = spec.clone().detach()
+        spec_ *= factor
+        return spec_
+
+    def forward(self, spec: torch.Tensor) -> torch.Tensor:
+        spec = self.augment_volume(spec, self.aug_volume_factor_range)
+        return spec
+
+
+def main():
+    spec_augment = SpecAugment()
+
+    spec = torch.randn(size=(1, 10, 4))
+    print(spec)
+
+    spec_ = spec_augment.forward(spec)
+    print(spec_)
+    return
+
+
+if __name__ == '__main__':
+    main()
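
A short, hedged usage sketch for the new module, mirroring its own main(): each forward call draws one uniform factor from aug_volume_factor_range and multiplies the whole spectrogram by it, so the output differs from the input only by a global gain.

# Hedged usage sketch for SpecAugment; the bounds below follow from the
# factor range (0.5, 2.0) passed to the constructor.
import torch

from toolbox.torchaudio.augment.spec_augment import SpecAugment

augment = SpecAugment(aug_volume_factor_range=(0.5, 2.0))

spec = torch.ones(size=(1, 10, 4))
augmented = augment.forward(spec)

gain = (augmented / spec).mean().item()
assert 0.5 <= gain <= 2.0
print(gain)  # one random gain, applied uniformly to every time-frequency bin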
toolbox/torchaudio/models/cnn_audio_classifier/configuration_cnn_audio_classifier.py CHANGED
@@ -8,6 +8,7 @@ from toolbox.torchaudio.configuration_utils import PretrainedConfig
 class CnnAudioClassifierConfig(PretrainedConfig):
     def __init__(self,
                  mel_spectrogram_param: dict,
+                 spec_augment_param: dict,
                  cls_head_param: dict,
                  conv1d_block_param_list: List[dict] = None,
                  conv2d_block_param_list: List[dict] = None,
@@ -15,6 +16,7 @@ class CnnAudioClassifierConfig(PretrainedConfig):
                  ):
         super(CnnAudioClassifierConfig, self).__init__(**kwargs)
         self.mel_spectrogram_param = mel_spectrogram_param
+        self.spec_augment_param = spec_augment_param
         self.cls_head_param = cls_head_param
         self.conv1d_block_param_list = conv1d_block_param_list
         self.conv2d_block_param_list = conv2d_block_param_list
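
A minimal sketch (assumption-laden, not from the commit) of constructing the updated config with the new field; the mel_spectrogram_param, cls_head_param, and conv2d_block_param_list contents are placeholders, and only spec_augment_param mirrors the YAML block added in this commit.

# Hedged sketch: the new spec_augment_param field alongside the existing ones.
from toolbox.torchaudio.models.cnn_audio_classifier.configuration_cnn_audio_classifier import CnnAudioClassifierConfig

config = CnnAudioClassifierConfig(
    mel_spectrogram_param={"n_mels": 80},                               # placeholder subset
    spec_augment_param={"aug_volume_factor_range": [0.5, 2.0]},         # new field from this commit
    cls_head_param={},                                                  # placeholder
    conv2d_block_param_list=[{"batch_norm": True, "in_channels": 1}],   # placeholder subset
)
print(config.spec_augment_param)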
toolbox/torchaudio/models/cnn_audio_classifier/modeling_cnn_audio_classifier.py CHANGED
@@ -9,6 +9,7 @@ import torchaudio
 
 from toolbox.torchaudio.models.cnn_audio_classifier.configuration_cnn_audio_classifier import CnnAudioClassifierConfig
 from toolbox.torchaudio.configuration_utils import CONFIG_FILE
+from toolbox.torchaudio.augment.spec_augment import SpecAugment
 
 
 MODEL_FILE = "model.pt"
@@ -240,6 +241,7 @@ class SpectrogramEncoder(nn.Module):
 class WaveEncoder(nn.Module):
     def __init__(self,
                  mel_spectrogram_param: dict,
+                 spec_augment_param: dict,
                  conv1d_block_param_list: List[dict] = None,
                  conv2d_block_param_list: List[dict] = None,
                  ):
@@ -262,6 +264,10 @@
             ),
         )
 
+        self.spec_augment = SpecAugment(
+            aug_volume_factor_range=spec_augment_param["aug_volume_factor_range"]
+        )
+
         self.spectrogram_encoder = SpectrogramEncoder(
             conv1d_block_param_list=conv1d_block_param_list,
             conv2d_block_param_list=conv2d_block_param_list,
@@ -277,6 +283,9 @@
         x = x.log()
         x = x - torch.mean(x, dim=-1, keepdim=True)
 
+        if self.training:
+            x = self.spec_augment.forward(x)
+
         x = x.transpose(1, 2)
 
         features = self.spectrogram_encoder.forward(x)
@@ -346,6 +355,7 @@ class WaveClassifierPretrainedModel(WaveClassifier):
         super(WaveClassifierPretrainedModel, self).__init__(
             wave_encoder=WaveEncoder(
                 mel_spectrogram_param=config.mel_spectrogram_param,
+                spec_augment_param=config.spec_augment_param,
                 conv1d_block_param_list=config.conv1d_block_param_list,
                 conv2d_block_param_list=config.conv2d_block_param_list,
             ),
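
Because the augmentation call added to WaveEncoder.forward() is gated on self.training, it only fires under model.train() and is skipped under model.eval(). Below is a hedged, self-contained sketch of that gate; GatedAugmentDemo is a hypothetical stand-in for the relevant part of WaveEncoder, and the feature shape is assumed.

# Hedged sketch of the `if self.training:` gate added to WaveEncoder.forward():
# the augmentation is applied to the normalized log-mel features only in
# training mode and skipped during evaluation/inference.
import torch

from toolbox.torchaudio.augment.spec_augment import SpecAugment


class GatedAugmentDemo(torch.nn.Module):
    """Stand-in for the relevant part of WaveEncoder (illustrative only)."""

    def __init__(self):
        super().__init__()
        self.spec_augment = SpecAugment(aug_volume_factor_range=(0.5, 2.0))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.training:
            x = self.spec_augment.forward(x)
        return x


demo = GatedAugmentDemo()
x = torch.ones(size=(1, 80, 100))    # feature shape assumed for illustration

demo.train()
print(demo(x).mean().item())          # scaled by a random factor in [0.5, 2.0]

demo.eval()
print(demo(x).mean().item())          # unchanged: 1.0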