Christian J. Steinmetz committed on
Commit
3c4fcfb
β€’
1 Parent(s): 4d2eb76

updating classifier configs and adding in kwargs to pretrained models

Browse files
cfg/model/cls_panns_16k.yaml CHANGED
@@ -11,5 +11,5 @@ model:
11
  hop_length: 512
12
  n_mels: 128
13
  sample_rate: ${sample_rate}
14
- model_sample_rate: ${sample_rate}
15
 
 
11
  hop_length: 512
12
  n_mels: 128
13
  sample_rate: ${sample_rate}
14
+ model_sample_rate: 16000
15
 
cfg/model/{cls_panns_44k.yaml β†’ cls_panns_48k.yaml} RENAMED
File without changes
cfg/model/cls_panns_48k_64.yaml ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _global_
2
+ model:
3
+ _target_: remfx.models.FXClassifier
4
+ lr: 3e-4
5
+ lr_weight_decay: 1e-3
6
+ sample_rate: ${sample_rate}
7
+ mixup: False
8
+ network:
9
+ _target_: remfx.classifier.Cnn14
10
+ num_classes: ${num_classes}
11
+ n_fft: 2048
12
+ hop_length: 512
13
+ n_mels: 64
14
+ sample_rate: ${sample_rate}
15
+ model_sample_rate: ${sample_rate}
16
+ specaugment: False
17
+
cfg/model/{cls_panns_44k_mixup.yaml β†’ cls_panns_48k_mixup.yaml} RENAMED
File without changes
cfg/model/cls_panns_48k_specaugment.yaml ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _global_
2
+ model:
3
+ _target_: remfx.models.FXClassifier
4
+ lr: 3e-4
5
+ lr_weight_decay: 1e-3
6
+ sample_rate: ${sample_rate}
7
+ mixup: False
8
+ network:
9
+ _target_: remfx.classifier.Cnn14
10
+ num_classes: ${num_classes}
11
+ n_fft: 2048
12
+ hop_length: 512
13
+ n_mels: 128
14
+ sample_rate: ${sample_rate}
15
+ model_sample_rate: ${sample_rate}
16
+ specaugment: True
cfg/model/cls_panns_48k_specaugment_label_smoothing.yaml ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _global_
2
+ model:
3
+ _target_: remfx.models.FXClassifier
4
+ lr: 3e-4
5
+ lr_weight_decay: 1e-3
6
+ sample_rate: ${sample_rate}
7
+ mixup: False
8
+ label_smoothing: 0.15
9
+ network:
10
+ _target_: remfx.classifier.Cnn14
11
+ num_classes: ${num_classes}
12
+ n_fft: 2048
13
+ hop_length: 512
14
+ n_mels: 128
15
+ sample_rate: ${sample_rate}
16
+ model_sample_rate: ${sample_rate}
17
+ specaugment: True
cfg/model/cls_panns_pt.yaml CHANGED
@@ -4,6 +4,7 @@ model:
4
  lr: 3e-4
5
  lr_weight_decay: 1e-3
6
  sample_rate: ${sample_rate}
 
7
  network:
8
  _target_: remfx.classifier.PANNs
9
  num_classes: ${num_classes}
 
4
  lr: 3e-4
5
  lr_weight_decay: 1e-3
6
  sample_rate: ${sample_rate}
7
+ mixup: False
8
  network:
9
  _target_: remfx.classifier.PANNs
10
  num_classes: ${num_classes}
remfx/classifier.py CHANGED
@@ -31,7 +31,7 @@ class PANNs(torch.nn.Module):
31
  torch.nn.Linear(hidden_dim, num_classes),
32
  )
33
 
34
- def forward(self, x: torch.Tensor):
35
  with torch.no_grad():
36
  x = self.resample(x)
37
  embed = panns_hear.get_scene_embeddings(x.view(x.shape[0], -1), self.model)
@@ -59,7 +59,7 @@ class Wav2CLIP(nn.Module):
59
  torch.nn.Linear(hidden_dim, num_classes),
60
  )
61
 
62
- def forward(self, x: torch.Tensor):
63
  with torch.no_grad():
64
  x = self.resample(x)
65
  embed = wav2clip_hear.get_scene_embeddings(
@@ -89,7 +89,7 @@ class VGGish(nn.Module):
89
  torch.nn.Linear(hidden_dim, num_classes),
90
  )
91
 
92
- def forward(self, x: torch.Tensor):
93
  with torch.no_grad():
94
  x = self.resample(x)
95
  embed = hearbaseline.vggish.get_scene_embeddings(
@@ -119,7 +119,7 @@ class wav2vec2(nn.Module):
119
  torch.nn.Linear(hidden_dim, num_classes),
120
  )
121
 
122
- def forward(self, x: torch.Tensor):
123
  with torch.no_grad():
124
  x = self.resample(x)
125
  embed = hearbaseline.wav2vec2.get_scene_embeddings(
@@ -179,6 +179,10 @@ class Cnn14(nn.Module):
179
  orig_freq=sample_rate, new_freq=model_sample_rate
180
  )
181
 
 
 
 
 
182
  def init_weight(self):
183
  init_bn(self.bn0)
184
  init_layer(self.fc1)
 
31
  torch.nn.Linear(hidden_dim, num_classes),
32
  )
33
 
34
+ def forward(self, x: torch.Tensor, **kwargs):
35
  with torch.no_grad():
36
  x = self.resample(x)
37
  embed = panns_hear.get_scene_embeddings(x.view(x.shape[0], -1), self.model)
 
59
  torch.nn.Linear(hidden_dim, num_classes),
60
  )
61
 
62
+ def forward(self, x: torch.Tensor, **kwargs):
63
  with torch.no_grad():
64
  x = self.resample(x)
65
  embed = wav2clip_hear.get_scene_embeddings(
 
89
  torch.nn.Linear(hidden_dim, num_classes),
90
  )
91
 
92
+ def forward(self, x: torch.Tensor, **kwargs):
93
  with torch.no_grad():
94
  x = self.resample(x)
95
  embed = hearbaseline.vggish.get_scene_embeddings(
 
119
  torch.nn.Linear(hidden_dim, num_classes),
120
  )
121
 
122
+ def forward(self, x: torch.Tensor, **kwargs):
123
  with torch.no_grad():
124
  x = self.resample(x)
125
  embed = hearbaseline.wav2vec2.get_scene_embeddings(
 
179
  orig_freq=sample_rate, new_freq=model_sample_rate
180
  )
181
 
182
+ if self.specaugment:
183
+ self.freq_mask = torchaudio.transforms.FrequencyMasking(64, True)
184
+ self.time_mask = torchaudio.transforms.TimeMasking(128, True)
185
+
186
  def init_weight(self):
187
  init_bn(self.bn0)
188
  init_layer(self.fc1)