mattricesound committed
Commit 08ea65f
Parent: 7e4b346

Remove hearbaseline

Files changed (1):
1. remfx/classifier.py (+0 -63)
remfx/classifier.py CHANGED
@@ -1,9 +1,6 @@
 import torch
 import torchaudio
 import torch.nn as nn
-import hearbaseline
-import hearbaseline.vggish
-import hearbaseline.wav2vec2

 import wav2clip_hear
 import panns_hear
@@ -68,66 +65,6 @@ class Wav2CLIP(nn.Module):
         return self.proj(embed)


-class VGGish(nn.Module):
-    def __init__(
-        self,
-        num_classes: int,
-        sample_rate: float,
-        hidden_dim: int = 256,
-    ):
-        super().__init__()
-        self.num_classes = num_classes
-        self.resample = torchaudio.transforms.Resample(
-            orig_freq=sample_rate, new_freq=16000
-        )
-        self.model = hearbaseline.vggish.load_model()
-        self.proj = torch.nn.Sequential(
-            torch.nn.Linear(128, hidden_dim),
-            torch.nn.ReLU(),
-            torch.nn.Linear(hidden_dim, hidden_dim),
-            torch.nn.ReLU(),
-            torch.nn.Linear(hidden_dim, num_classes),
-        )
-
-    def forward(self, x: torch.Tensor, **kwargs):
-        with torch.no_grad():
-            x = self.resample(x)
-            embed = hearbaseline.vggish.get_scene_embeddings(
-                x.view(x.shape[0], -1), self.model
-            )
-        return self.proj(embed)
-
-
-class wav2vec2(nn.Module):
-    def __init__(
-        self,
-        num_classes: int,
-        sample_rate: float,
-        hidden_dim: int = 256,
-    ):
-        super().__init__()
-        self.num_classes = num_classes
-        self.resample = torchaudio.transforms.Resample(
-            orig_freq=sample_rate, new_freq=16000
-        )
-        self.model = hearbaseline.wav2vec2.load_model()
-        self.proj = torch.nn.Sequential(
-            torch.nn.Linear(1024, hidden_dim),
-            torch.nn.ReLU(),
-            torch.nn.Linear(hidden_dim, hidden_dim),
-            torch.nn.ReLU(),
-            torch.nn.Linear(hidden_dim, num_classes),
-        )
-
-    def forward(self, x: torch.Tensor, **kwargs):
-        with torch.no_grad():
-            x = self.resample(x)
-            embed = hearbaseline.wav2vec2.get_scene_embeddings(
-                x.view(x.shape[0], -1), self.model
-            )
-        return self.proj(embed)
-
-
 # adapted from https://github.com/qiuqiangkong/audioset_tagging_cnn/blob/master/pytorch/models.py
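Both removed wrappers followed the same recipe as the Wav2CLIP and PANNs classifiers that stay in the file: resample the input to 16 kHz, compute a scene embedding with the backbone frozen under torch.no_grad(), and train only a small MLP projection head on top. Below is a minimal sketch of that recipe without the hearbaseline dependency; `EmbeddingClassifier`, `embed_fn`, and `embed_dim` are hypothetical names introduced for illustration, not part of this repository's API.

import torch
import torch.nn as nn
import torchaudio


class EmbeddingClassifier(nn.Module):
    """Frozen pretrained embedding followed by a trainable MLP head (sketch)."""

    def __init__(
        self,
        embed_fn,  # assumed callable: (batch, samples) -> (batch, embed_dim)
        embed_dim: int,
        num_classes: int,
        sample_rate: float,
        hidden_dim: int = 256,
    ):
        super().__init__()
        self.embed_fn = embed_fn
        # The removed wrappers resampled to 16 kHz before embedding.
        self.resample = torchaudio.transforms.Resample(
            orig_freq=sample_rate, new_freq=16000
        )
        self.proj = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, num_classes),
        )

    def forward(self, x: torch.Tensor, **kwargs):
        # Backbone stays frozen; gradients flow only through the head.
        with torch.no_grad():
            x = self.resample(x)
            embed = self.embed_fn(x.view(x.shape[0], -1))
        return self.proj(embed)

For instance, a hypothetical backend producing 512-dimensional embeddings would be wrapped as EmbeddingClassifier(my_embed_fn, embed_dim=512, num_classes=5, sample_rate=48000); the removed classes hard-coded 128 (VGGish) and 1024 (wav2vec2) for that input width.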