nimocodes committed on
Commit 2c8d5d8 · verified · 1 Parent(s): e3b8e9c

Upload 6 files

models/classifiers.py ADDED
@@ -0,0 +1,172 @@
+ from functools import partial
+
+ import numpy as np
+ import torch
+ from timm.models.efficientnet import tf_efficientnet_b4_ns, tf_efficientnet_b3_ns, \
+     tf_efficientnet_b5_ns, tf_efficientnet_b2_ns, tf_efficientnet_b6_ns, tf_efficientnet_b7_ns
+ from torch import nn
+ from torch.nn.modules.dropout import Dropout
+ from torch.nn.modules.linear import Linear
+ from torch.nn.modules.pooling import AdaptiveAvgPool2d
+
+ encoder_params = {
+     "tf_efficientnet_b3_ns": {
+         "features": 1536,
+         "init_op": partial(tf_efficientnet_b3_ns, pretrained=True, drop_path_rate=0.2)
+     },
+     "tf_efficientnet_b2_ns": {
+         "features": 1408,
+         "init_op": partial(tf_efficientnet_b2_ns, pretrained=False, drop_path_rate=0.2)
+     },
+     "tf_efficientnet_b4_ns": {
+         "features": 1792,
+         "init_op": partial(tf_efficientnet_b4_ns, pretrained=True, drop_path_rate=0.5)
+     },
+     "tf_efficientnet_b5_ns": {
+         "features": 2048,
+         "init_op": partial(tf_efficientnet_b5_ns, pretrained=True, drop_path_rate=0.2)
+     },
+     "tf_efficientnet_b4_ns_03d": {
+         "features": 1792,
+         "init_op": partial(tf_efficientnet_b4_ns, pretrained=True, drop_path_rate=0.3)
+     },
+     "tf_efficientnet_b5_ns_03d": {
+         "features": 2048,
+         "init_op": partial(tf_efficientnet_b5_ns, pretrained=True, drop_path_rate=0.3)
+     },
+     "tf_efficientnet_b5_ns_04d": {
+         "features": 2048,
+         "init_op": partial(tf_efficientnet_b5_ns, pretrained=True, drop_path_rate=0.4)
+     },
+     "tf_efficientnet_b6_ns": {
+         "features": 2304,
+         "init_op": partial(tf_efficientnet_b6_ns, pretrained=True, drop_path_rate=0.2)
+     },
+     "tf_efficientnet_b7_ns": {
+         "features": 2560,
+         "init_op": partial(tf_efficientnet_b7_ns, pretrained=False, drop_path_rate=0.2)
+     },
+     "tf_efficientnet_b6_ns_04d": {
+         "features": 2304,
+         "init_op": partial(tf_efficientnet_b6_ns, pretrained=True, drop_path_rate=0.4)
+     },
+ }
+
+
+ def setup_srm_weights(input_channels: int = 3) -> torch.Tensor:
+     """Creates the SRM kernels for noise analysis."""
+     # note: values taken from Zhou et al., "Learning Rich Features for Image Manipulation Detection", CVPR 2018
+     srm_kernel = torch.from_numpy(np.array([
+         [  # srm 1/2 horiz
+             [0.,  0.,  0.,  0., 0.],  # noqa: E241,E201
+             [0.,  0.,  0.,  0., 0.],  # noqa: E241,E201
+             [0.,  1., -2.,  1., 0.],  # noqa: E241,E201
+             [0.,  0.,  0.,  0., 0.],  # noqa: E241,E201
+             [0.,  0.,  0.,  0., 0.],  # noqa: E241,E201
+         ], [  # srm 1/4
+             [0.,  0.,  0.,  0., 0.],  # noqa: E241,E201
+             [0., -1.,  2., -1., 0.],  # noqa: E241,E201
+             [0.,  2., -4.,  2., 0.],  # noqa: E241,E201
+             [0., -1.,  2., -1., 0.],  # noqa: E241,E201
+             [0.,  0.,  0.,  0., 0.],  # noqa: E241,E201
+         ], [  # srm 1/12
+             [-1.,  2.,  -2.,  2., -1.],  # noqa: E241,E201
+             [ 2., -6.,   8., -6.,  2.],  # noqa: E241,E201
+             [-2.,  8., -12.,  8., -2.],  # noqa: E241,E201
+             [ 2., -6.,   8., -6.,  2.],  # noqa: E241,E201
+             [-1.,  2.,  -2.,  2., -1.],  # noqa: E241,E201
+         ]
+     ])).float()
+     srm_kernel[0] /= 2
+     srm_kernel[1] /= 4
+     srm_kernel[2] /= 12
+     return srm_kernel.view(3, 1, 5, 5).repeat(1, input_channels, 1, 1)
+
+
+ def setup_srm_layer(input_channels: int = 3) -> torch.nn.Module:
+     """Creates an SRM convolution layer for noise analysis."""
+     weights = setup_srm_weights(input_channels)
+     conv = torch.nn.Conv2d(input_channels, out_channels=3, kernel_size=5, stride=1, padding=2, bias=False)
+     with torch.no_grad():
+         conv.weight = torch.nn.Parameter(weights, requires_grad=False)
+     return conv
+
+
+ class DeepFakeClassifierSRM(nn.Module):
+     def __init__(self, encoder, dropout_rate=0.5) -> None:
+         super().__init__()
+         self.encoder = encoder_params[encoder]["init_op"]()
+         self.avg_pool = AdaptiveAvgPool2d((1, 1))
+         self.srm_conv = setup_srm_layer(3)
+         self.dropout = Dropout(dropout_rate)
+         self.fc = Linear(encoder_params[encoder]["features"], 1)
+
+     def forward(self, x):
+         noise = self.srm_conv(x)
+         x = self.encoder.forward_features(noise)
+         x = self.avg_pool(x).flatten(1)
+         x = self.dropout(x)
+         x = self.fc(x)
+         return x
+
+
+ class GlobalWeightedAvgPool2d(nn.Module):
+     """
+     Global Weighted Average Pooling from the paper "Global Weighted Average
+     Pooling Bridges Pixel-level Localization and Image-level Classification"
+     """
+
+     def __init__(self, features: int, flatten=False):
+         super().__init__()
+         self.conv = nn.Conv2d(features, 1, kernel_size=1, bias=True)
+         self.flatten = flatten
+
+     def fscore(self, x):
+         m = self.conv(x)
+         m = m.sigmoid().exp()
+         return m
+
+     def norm(self, x: torch.Tensor):
+         return x / x.sum(dim=[2, 3], keepdim=True)
+
+     def forward(self, x):
+         input_x = x
+         x = self.fscore(x)
+         x = self.norm(x)
+         x = x * input_x
+         x = x.sum(dim=[2, 3], keepdim=not self.flatten)
+         return x
+
+
+ class DeepFakeClassifier(nn.Module):
+     def __init__(self, encoder, dropout_rate=0.0) -> None:
+         super().__init__()
+         self.encoder = encoder_params[encoder]["init_op"]()
+         self.avg_pool = AdaptiveAvgPool2d((1, 1))
+         self.dropout = Dropout(dropout_rate)
+         self.fc = Linear(encoder_params[encoder]["features"], 1)
+
+     def forward(self, x):
+         x = self.encoder.forward_features(x)
+         x = self.avg_pool(x).flatten(1)
+         x = self.dropout(x)
+         x = self.fc(x)
+         return x
+
+
+ class DeepFakeClassifierGWAP(nn.Module):
+     def __init__(self, encoder, dropout_rate=0.5) -> None:
+         super().__init__()
+         self.encoder = encoder_params[encoder]["init_op"]()
+         self.avg_pool = GlobalWeightedAvgPool2d(encoder_params[encoder]["features"])
+         self.dropout = Dropout(dropout_rate)
+         self.fc = Linear(encoder_params[encoder]["features"], 1)
+
+     def forward(self, x):
+         x = self.encoder.forward_features(x)
+         x = self.avg_pool(x).flatten(1)
+         x = self.dropout(x)
+         x = self.fc(x)
+         return x
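
A minimal smoke test for classifiers.py, not part of the commit. It assumes a timm version that still exports the tf_efficientnet_*_ns constructors imported above, and it uses the b7 entry because its init_op sets pretrained=False, so no weight download is needed:

    import torch
    from models.classifiers import DeepFakeClassifier, setup_srm_layer

    srm = setup_srm_layer(3)                 # fixed, non-trainable 5x5 SRM kernels
    x = torch.randn(2, 3, 256, 256)          # dummy RGB batch
    assert srm(x).shape == x.shape           # kernel_size=5 with padding=2 preserves H and W

    model = DeepFakeClassifier(encoder="tf_efficientnet_b7_ns").eval()
    with torch.no_grad():
        logits = model(x)                    # one real/fake logit per image
    assert logits.shape == (2, 1)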
models/efficientnet.onnx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:39774e1cc878ac2b587fd4dc1c96fba084c9fe5ee3106a43b560f6054a69ba26
+ size 133
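
Note: models/efficientnet.onnx and models/model.pth are Git LFS pointer files; after cloning, the tracked binaries are fetched with git lfs pull. The size field of an LFS pointer records the byte size of the tracked file itself, and 133-134 bytes is far too small for model weights, which usually indicates that a pointer file was uploaded in place of the real binary.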
models/image.py ADDED
@@ -0,0 +1,195 @@
+ import re
+ import os
+ import wget
+ import torch
+ import torchvision
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from models.rawnet import SincConv, Residual_block
+ from models.classifiers import DeepFakeClassifier
+
+
+ class ImageEncoder(nn.Module):
+     def __init__(self, args):
+         super(ImageEncoder, self).__init__()
+         self.device = args.device
+         self.args = args
+         self.flatten = nn.Flatten()
+         self.sigmoid = nn.Sigmoid()
+         # self.fc = nn.Linear(in_features=2560, out_features=2)
+         self.pretrained_image_encoder = args.pretrained_image_encoder
+         self.freeze_image_encoder = args.freeze_image_encoder
+
+         self.model = DeepFakeClassifier(encoder="tf_efficientnet_b7_ns").to(self.device)
+
+         if self.pretrained_image_encoder:
+             # os.path.join keeps the checkpoint path portable across OSes
+             self.pretrained_ckpt = torch.load(os.path.join('pretrained', 'final_999_DeepFakeClassifier_tf_efficientnet_b7_ns_0_23'),
+                                               map_location=torch.device(self.args.device))
+             self.state_dict = self.pretrained_ckpt.get("state_dict", self.pretrained_ckpt)
+             print("Loading pretrained image encoder...")
+             # strip an optional DataParallel "module." prefix from checkpoint keys (note the escaped dot)
+             self.model.load_state_dict({re.sub(r"^module\.", "", k): v for k, v in self.state_dict.items()}, strict=True)
+             print("Loaded pretrained image encoder.")
+
+         if self.freeze_image_encoder:
+             for _, param in self.model.named_parameters():
+                 param.requires_grad = False
+
+         # self.model.fc = nn.Identity()
+
+     def forward(self, x):
+         x = self.model(x)
+         out = self.sigmoid(x)
+         # x = self.flatten(x)
+         # out = self.fc(x)
+         return out
+
+
+
48
+ class RawNet(nn.Module):
49
+ def __init__(self, args):
50
+ super(RawNet, self).__init__()
51
+
52
+ self.device=args.device
53
+ self.filts = [20, [20, 20], [20, 128], [128, 128]]
54
+
55
+ self.Sinc_conv=SincConv(device=self.device,
56
+ out_channels = self.filts[0],
57
+ kernel_size = 1024,
58
+ in_channels = args.in_channels)
59
+
60
+ self.first_bn = nn.BatchNorm1d(num_features = self.filts[0])
61
+ self.selu = nn.SELU(inplace=True)
62
+ self.block0 = nn.Sequential(Residual_block(nb_filts = self.filts[1], first = True))
63
+ self.block1 = nn.Sequential(Residual_block(nb_filts = self.filts[1]))
64
+ self.block2 = nn.Sequential(Residual_block(nb_filts = self.filts[2]))
65
+ self.filts[2][0] = self.filts[2][1]
66
+ self.block3 = nn.Sequential(Residual_block(nb_filts = self.filts[2]))
67
+ self.block4 = nn.Sequential(Residual_block(nb_filts = self.filts[2]))
68
+ self.block5 = nn.Sequential(Residual_block(nb_filts = self.filts[2]))
69
+ self.avgpool = nn.AdaptiveAvgPool1d(1)
70
+
71
+ self.fc_attention0 = self._make_attention_fc(in_features = self.filts[1][-1],
72
+ l_out_features = self.filts[1][-1])
73
+ self.fc_attention1 = self._make_attention_fc(in_features = self.filts[1][-1],
74
+ l_out_features = self.filts[1][-1])
75
+ self.fc_attention2 = self._make_attention_fc(in_features = self.filts[2][-1],
76
+ l_out_features = self.filts[2][-1])
77
+ self.fc_attention3 = self._make_attention_fc(in_features = self.filts[2][-1],
78
+ l_out_features = self.filts[2][-1])
79
+ self.fc_attention4 = self._make_attention_fc(in_features = self.filts[2][-1],
80
+ l_out_features = self.filts[2][-1])
81
+ self.fc_attention5 = self._make_attention_fc(in_features = self.filts[2][-1],
82
+ l_out_features = self.filts[2][-1])
83
+
84
+ self.bn_before_gru = nn.BatchNorm1d(num_features = self.filts[2][-1])
85
+ self.gru = nn.GRU(input_size = self.filts[2][-1],
86
+ hidden_size = args.gru_node,
87
+ num_layers = args.nb_gru_layer,
88
+ batch_first = True)
89
+
90
+ self.fc1_gru = nn.Linear(in_features = args.gru_node,
91
+ out_features = args.nb_fc_node)
92
+
93
+ self.fc2_gru = nn.Linear(in_features = args.nb_fc_node,
94
+ out_features = args.nb_classes ,bias=True)
95
+
96
+ self.sig = nn.Sigmoid()
97
+ self.logsoftmax = nn.LogSoftmax(dim=1)
98
+ self.pretrained_audio_encoder = args.pretrained_audio_encoder
99
+ self.freeze_audio_encoder = args.freeze_audio_encoder
100
+
101
+ if self.pretrained_audio_encoder == True:
102
+ print("Loading pretrained audio encoder")
103
+ ckpt = torch.load('pretrained\\RawNet.pth', map_location = torch.device(self.device))
104
+ print("Loaded pretrained audio encoder")
105
+ self.load_state_dict(ckpt, strict = True)
106
+
107
+ if self.freeze_audio_encoder:
108
+ for param in self.parameters():
109
+ param.requires_grad = False
110
+
111
+
112
+ def forward(self, x, y = None):
113
+
114
+ nb_samp = x.shape[0]
115
+ len_seq = x.shape[1]
116
+ x=x.view(nb_samp,1,len_seq)
117
+
118
+ x = self.Sinc_conv(x)
119
+ x = F.max_pool1d(torch.abs(x), 3)
120
+ x = self.first_bn(x)
121
+ x = self.selu(x)
122
+
123
+ x0 = self.block0(x)
124
+ y0 = self.avgpool(x0).view(x0.size(0), -1) # torch.Size([batch, filter])
125
+ y0 = self.fc_attention0(y0)
126
+ y0 = self.sig(y0).view(y0.size(0), y0.size(1), -1) # torch.Size([batch, filter, 1])
127
+ x = x0 * y0 + y0 # (batch, filter, time) x (batch, filter, 1)
128
+
129
+
130
+ x1 = self.block1(x)
131
+ y1 = self.avgpool(x1).view(x1.size(0), -1) # torch.Size([batch, filter])
132
+ y1 = self.fc_attention1(y1)
133
+ y1 = self.sig(y1).view(y1.size(0), y1.size(1), -1) # torch.Size([batch, filter, 1])
134
+ x = x1 * y1 + y1 # (batch, filter, time) x (batch, filter, 1)
135
+
136
+ x2 = self.block2(x)
137
+ y2 = self.avgpool(x2).view(x2.size(0), -1) # torch.Size([batch, filter])
138
+ y2 = self.fc_attention2(y2)
139
+ y2 = self.sig(y2).view(y2.size(0), y2.size(1), -1) # torch.Size([batch, filter, 1])
140
+ x = x2 * y2 + y2 # (batch, filter, time) x (batch, filter, 1)
141
+
142
+ x3 = self.block3(x)
143
+ y3 = self.avgpool(x3).view(x3.size(0), -1) # torch.Size([batch, filter])
144
+ y3 = self.fc_attention3(y3)
145
+ y3 = self.sig(y3).view(y3.size(0), y3.size(1), -1) # torch.Size([batch, filter, 1])
146
+ x = x3 * y3 + y3 # (batch, filter, time) x (batch, filter, 1)
147
+
148
+ x4 = self.block4(x)
149
+ y4 = self.avgpool(x4).view(x4.size(0), -1) # torch.Size([batch, filter])
150
+ y4 = self.fc_attention4(y4)
151
+ y4 = self.sig(y4).view(y4.size(0), y4.size(1), -1) # torch.Size([batch, filter, 1])
152
+ x = x4 * y4 + y4 # (batch, filter, time) x (batch, filter, 1)
153
+
154
+ x5 = self.block5(x)
155
+ y5 = self.avgpool(x5).view(x5.size(0), -1) # torch.Size([batch, filter])
156
+ y5 = self.fc_attention5(y5)
157
+ y5 = self.sig(y5).view(y5.size(0), y5.size(1), -1) # torch.Size([batch, filter, 1])
158
+ x = x5 * y5 + y5 # (batch, filter, time) x (batch, filter, 1)
159
+
160
+ x = self.bn_before_gru(x)
161
+ x = self.selu(x)
162
+ x = x.permute(0, 2, 1) #(batch, filt, time) >> (batch, time, filt)
163
+ self.gru.flatten_parameters()
164
+ x, _ = self.gru(x)
165
+ x = x[:,-1,:]
166
+ x = self.fc1_gru(x)
167
+ x = self.fc2_gru(x)
168
+ output=self.logsoftmax(x)
169
+
170
+ return output
171
+
172
+
173
+
174
+ def _make_attention_fc(self, in_features, l_out_features):
175
+
176
+ l_fc = []
177
+
178
+ l_fc.append(nn.Linear(in_features = in_features,
179
+ out_features = l_out_features))
180
+
181
+
182
+
183
+ return nn.Sequential(*l_fc)
184
+
185
+
186
+ def _make_layer(self, nb_blocks, nb_filts, first = False):
187
+ layers = []
188
+ #def __init__(self, nb_filts, first = False):
189
+ for i in range(nb_blocks):
190
+ first = first if i == 0 else False
191
+ layers.append(Residual_block(nb_filts = nb_filts,
192
+ first = first))
193
+ if i == 0: nb_filts[0] = nb_filts[1]
194
+
195
+ return nn.Sequential(*layers)
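
A minimal construction sketch for ImageEncoder, not part of the commit. The project's real argument parser is not in this upload, so SimpleNamespace stands in for it with only the three fields the class actually reads; pretrained_image_encoder=False skips the checkpoint load:

    import torch
    from types import SimpleNamespace
    from models.image import ImageEncoder

    args = SimpleNamespace(device="cpu",
                           pretrained_image_encoder=False,  # skip the checkpoint load
                           freeze_image_encoder=False)
    enc = ImageEncoder(args).eval()
    with torch.no_grad():
        prob = enc(torch.randn(1, 3, 256, 256))  # sigmoid output in (0, 1)
    assert prob.shape == (1, 1)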
models/links.txt CHANGED
@@ -0,0 +1 @@
+
models/model.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0315e9ad76374c2e0f91249847d4b1c8ad8c2b20ac334836e8e79657daa4b63a
+ size 134
models/rawnet.py ADDED
@@ -0,0 +1,360 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torch import Tensor
+ import numpy as np
+ from torch.utils import data
+ from collections import OrderedDict
+ from torch.nn.parameter import Parameter
+
+
+ class SincConv(nn.Module):
+     @staticmethod
+     def to_mel(hz):
+         return 2595 * np.log10(1 + hz / 700)
+
+     @staticmethod
+     def to_hz(mel):
+         return 700 * (10 ** (mel / 2595) - 1)
+
+     def __init__(self, device, out_channels, kernel_size, in_channels=1, sample_rate=16000,
+                  stride=1, padding=0, dilation=1, bias=False, groups=1):
+         super(SincConv, self).__init__()
+
+         if in_channels != 1:
+             msg = "SincConv only supports one input channel (here, in_channels = %i)" % in_channels
+             raise ValueError(msg)
+
+         self.out_channels = out_channels
+         self.kernel_size = kernel_size
+         self.sample_rate = sample_rate
+
+         # Forcing the filters to be odd (i.e., perfectly symmetric)
+         if kernel_size % 2 == 0:
+             self.kernel_size = self.kernel_size + 1
+
+         self.device = device
+         self.stride = stride
+         self.padding = padding
+         self.dilation = dilation
+
+         if bias:
+             raise ValueError('SincConv does not support bias.')
+         if groups > 1:
+             raise ValueError('SincConv does not support groups.')
+
+         # initialize filterbanks using the Mel scale
+         NFFT = 512
+         f = int(self.sample_rate / 2) * np.linspace(0, 1, int(NFFT / 2) + 1)
+         fmel = self.to_mel(f)  # Hz to mel conversion
+         fmelmax = np.max(fmel)
+         fmelmin = np.min(fmel)
+         filbandwidthsmel = np.linspace(fmelmin, fmelmax, self.out_channels + 1)
+         filbandwidthsf = self.to_hz(filbandwidthsmel)  # mel to Hz conversion
+         self.mel = filbandwidthsf
+         self.hsupp = torch.arange(-(self.kernel_size - 1) / 2, (self.kernel_size - 1) / 2 + 1)
+         self.band_pass = torch.zeros(self.out_channels, self.kernel_size)
+
+     def forward(self, x):
+         # each band-pass filter is the difference of two Hamming-windowed sinc low-pass filters
+         for i in range(len(self.mel) - 1):
+             fmin = self.mel[i]
+             fmax = self.mel[i + 1]
+             hHigh = (2 * fmax / self.sample_rate) * np.sinc(2 * fmax * self.hsupp / self.sample_rate)
+             hLow = (2 * fmin / self.sample_rate) * np.sinc(2 * fmin * self.hsupp / self.sample_rate)
+             hideal = hHigh - hLow
+
+             self.band_pass[i, :] = Tensor(np.hamming(self.kernel_size)) * Tensor(hideal)
+
+         band_pass_filter = self.band_pass.to(self.device)
+
+         self.filters = band_pass_filter.view(self.out_channels, 1, self.kernel_size)
+
+         return F.conv1d(x, self.filters, stride=self.stride,
+                         padding=self.padding, dilation=self.dilation,
+                         bias=None, groups=1)
+
+
+ class Residual_block(nn.Module):
+     def __init__(self, nb_filts, first=False):
+         super(Residual_block, self).__init__()
+         self.first = first
+
+         if not self.first:
+             self.bn1 = nn.BatchNorm1d(num_features=nb_filts[0])
+
+         self.lrelu = nn.LeakyReLU(negative_slope=0.3)
+
+         self.conv1 = nn.Conv1d(in_channels=nb_filts[0],
+                                out_channels=nb_filts[1],
+                                kernel_size=3,
+                                padding=1,
+                                stride=1)
+
+         self.bn2 = nn.BatchNorm1d(num_features=nb_filts[1])
+         self.conv2 = nn.Conv1d(in_channels=nb_filts[1],
+                                out_channels=nb_filts[1],
+                                padding=1,
+                                kernel_size=3,
+                                stride=1)
+
+         if nb_filts[0] != nb_filts[1]:
+             self.downsample = True
+             self.conv_downsample = nn.Conv1d(in_channels=nb_filts[0],
+                                              out_channels=nb_filts[1],
+                                              padding=0,
+                                              kernel_size=1,
+                                              stride=1)
+         else:
+             self.downsample = False
+         self.mp = nn.MaxPool1d(3)
+
+     def forward(self, x):
+         identity = x
+         if not self.first:
+             out = self.bn1(x)
+             out = self.lrelu(out)
+         else:
+             out = x
+
+         # conv1 takes x rather than out, so the bn1/lrelu result above goes unused;
+         # this mirrors the original RawNet code and is kept for checkpoint compatibility
+         out = self.conv1(x)
+         out = self.bn2(out)
+         out = self.lrelu(out)
+         out = self.conv2(out)
+
+         if self.downsample:
+             identity = self.conv_downsample(identity)
+
+         out += identity
+         out = self.mp(out)
+         return out
+
+
+ class RawNet(nn.Module):
+     def __init__(self, d_args, device):
+         super(RawNet, self).__init__()
+
+         self.device = device
+
+         self.Sinc_conv = SincConv(device=self.device,
+                                   out_channels=d_args['filts'][0],
+                                   kernel_size=d_args['first_conv'],
+                                   in_channels=d_args['in_channels'])
+
+         self.first_bn = nn.BatchNorm1d(num_features=d_args['filts'][0])
+         self.selu = nn.SELU(inplace=True)
+         self.block0 = nn.Sequential(Residual_block(nb_filts=d_args['filts'][1], first=True))
+         self.block1 = nn.Sequential(Residual_block(nb_filts=d_args['filts'][1]))
+         self.block2 = nn.Sequential(Residual_block(nb_filts=d_args['filts'][2]))
+         d_args['filts'][2][0] = d_args['filts'][2][1]  # blocks 3-5 take the widened channel count
+         self.block3 = nn.Sequential(Residual_block(nb_filts=d_args['filts'][2]))
+         self.block4 = nn.Sequential(Residual_block(nb_filts=d_args['filts'][2]))
+         self.block5 = nn.Sequential(Residual_block(nb_filts=d_args['filts'][2]))
+         self.avgpool = nn.AdaptiveAvgPool1d(1)
+
+         self.fc_attention0 = self._make_attention_fc(in_features=d_args['filts'][1][-1],
+                                                      l_out_features=d_args['filts'][1][-1])
+         self.fc_attention1 = self._make_attention_fc(in_features=d_args['filts'][1][-1],
+                                                      l_out_features=d_args['filts'][1][-1])
+         self.fc_attention2 = self._make_attention_fc(in_features=d_args['filts'][2][-1],
+                                                      l_out_features=d_args['filts'][2][-1])
+         self.fc_attention3 = self._make_attention_fc(in_features=d_args['filts'][2][-1],
+                                                      l_out_features=d_args['filts'][2][-1])
+         self.fc_attention4 = self._make_attention_fc(in_features=d_args['filts'][2][-1],
+                                                      l_out_features=d_args['filts'][2][-1])
+         self.fc_attention5 = self._make_attention_fc(in_features=d_args['filts'][2][-1],
+                                                      l_out_features=d_args['filts'][2][-1])
+
+         self.bn_before_gru = nn.BatchNorm1d(num_features=d_args['filts'][2][-1])
+         self.gru = nn.GRU(input_size=d_args['filts'][2][-1],
+                           hidden_size=d_args['gru_node'],
+                           num_layers=d_args['nb_gru_layer'],
+                           batch_first=True)
+
+         self.fc1_gru = nn.Linear(in_features=d_args['gru_node'],
+                                  out_features=d_args['nb_fc_node'])
+
+         self.fc2_gru = nn.Linear(in_features=d_args['nb_fc_node'],
+                                  out_features=d_args['nb_classes'], bias=True)
+
+         self.sig = nn.Sigmoid()
+         self.logsoftmax = nn.LogSoftmax(dim=1)
+
+     def forward(self, x, y=None):
+         nb_samp = x.shape[0]
+         len_seq = x.shape[1]
+         x = x.view(nb_samp, 1, len_seq)
+
+         x = self.Sinc_conv(x)
+         x = F.max_pool1d(torch.abs(x), 3)
+         x = self.first_bn(x)
+         x = self.selu(x)
+
+         x0 = self.block0(x)
+         y0 = self.avgpool(x0).view(x0.size(0), -1)  # torch.Size([batch, filter])
+         y0 = self.fc_attention0(y0)
+         y0 = self.sig(y0).view(y0.size(0), y0.size(1), -1)  # torch.Size([batch, filter, 1])
+         x = x0 * y0 + y0  # (batch, filter, time) x (batch, filter, 1)
+
+         x1 = self.block1(x)
+         y1 = self.avgpool(x1).view(x1.size(0), -1)  # torch.Size([batch, filter])
+         y1 = self.fc_attention1(y1)
+         y1 = self.sig(y1).view(y1.size(0), y1.size(1), -1)  # torch.Size([batch, filter, 1])
+         x = x1 * y1 + y1  # (batch, filter, time) x (batch, filter, 1)
+
+         x2 = self.block2(x)
+         y2 = self.avgpool(x2).view(x2.size(0), -1)  # torch.Size([batch, filter])
+         y2 = self.fc_attention2(y2)
+         y2 = self.sig(y2).view(y2.size(0), y2.size(1), -1)  # torch.Size([batch, filter, 1])
+         x = x2 * y2 + y2  # (batch, filter, time) x (batch, filter, 1)
+
+         x3 = self.block3(x)
+         y3 = self.avgpool(x3).view(x3.size(0), -1)  # torch.Size([batch, filter])
+         y3 = self.fc_attention3(y3)
+         y3 = self.sig(y3).view(y3.size(0), y3.size(1), -1)  # torch.Size([batch, filter, 1])
+         x = x3 * y3 + y3  # (batch, filter, time) x (batch, filter, 1)
+
+         x4 = self.block4(x)
+         y4 = self.avgpool(x4).view(x4.size(0), -1)  # torch.Size([batch, filter])
+         y4 = self.fc_attention4(y4)
+         y4 = self.sig(y4).view(y4.size(0), y4.size(1), -1)  # torch.Size([batch, filter, 1])
+         x = x4 * y4 + y4  # (batch, filter, time) x (batch, filter, 1)
+
+         x5 = self.block5(x)
+         y5 = self.avgpool(x5).view(x5.size(0), -1)  # torch.Size([batch, filter])
+         y5 = self.fc_attention5(y5)
+         y5 = self.sig(y5).view(y5.size(0), y5.size(1), -1)  # torch.Size([batch, filter, 1])
+         x = x5 * y5 + y5  # (batch, filter, time) x (batch, filter, 1)
+
+         x = self.bn_before_gru(x)
+         x = self.selu(x)
+         x = x.permute(0, 2, 1)  # (batch, filt, time) >> (batch, time, filt)
+         self.gru.flatten_parameters()
+         x, _ = self.gru(x)
+         x = x[:, -1, :]
+         x = self.fc1_gru(x)
+         x = self.fc2_gru(x)
+         output = self.logsoftmax(x)
+         print(f"Spec output shape: {output.shape}")
+
+         return output
+
+     def _make_attention_fc(self, in_features, l_out_features):
+         l_fc = []
+         l_fc.append(nn.Linear(in_features=in_features,
+                               out_features=l_out_features))
+         return nn.Sequential(*l_fc)
+
+     def _make_layer(self, nb_blocks, nb_filts, first=False):
+         layers = []
+         for i in range(nb_blocks):
+             first = first if i == 0 else False
+             layers.append(Residual_block(nb_filts=nb_filts,
+                                          first=first))
+             if i == 0:
+                 nb_filts[0] = nb_filts[1]
+
+         return nn.Sequential(*layers)
+
+     def summary(self, input_size, batch_size=-1, device="cuda", print_fn=None):
+         if print_fn is None:
+             print_fn = print
+         model = self
+
+         def register_hook(module):
+             def hook(module, input, output):
+                 class_name = str(module.__class__).split(".")[-1].split("'")[0]
+                 module_idx = len(summary)
+
+                 m_key = "%s-%i" % (class_name, module_idx + 1)
+                 summary[m_key] = OrderedDict()
+                 summary[m_key]["input_shape"] = list(input[0].size())
+                 summary[m_key]["input_shape"][0] = batch_size
+                 if isinstance(output, (list, tuple)):
+                     summary[m_key]["output_shape"] = [
+                         [-1] + list(o.size())[1:] for o in output
+                     ]
+                 else:
+                     summary[m_key]["output_shape"] = list(output.size())
+                     if len(summary[m_key]["output_shape"]) != 0:
+                         summary[m_key]["output_shape"][0] = batch_size
+
+                 params = 0
+                 if hasattr(module, "weight") and hasattr(module.weight, "size"):
+                     params += torch.prod(torch.LongTensor(list(module.weight.size())))
+                     summary[m_key]["trainable"] = module.weight.requires_grad
+                 if hasattr(module, "bias") and hasattr(module.bias, "size"):
+                     params += torch.prod(torch.LongTensor(list(module.bias.size())))
+                 summary[m_key]["nb_params"] = params
+
+             if (
+                 not isinstance(module, nn.Sequential)
+                 and not isinstance(module, nn.ModuleList)
+                 and not (module == model)
+             ):
+                 hooks.append(module.register_forward_hook(hook))
+
+         device = device.lower()
+         assert device in [
+             "cuda",
+             "cpu",
+         ], "Input device is not valid, please specify 'cuda' or 'cpu'"
+
+         if device == "cuda" and torch.cuda.is_available():
+             dtype = torch.cuda.FloatTensor
+         else:
+             dtype = torch.FloatTensor
+         if isinstance(input_size, tuple):
+             input_size = [input_size]
+         x = [torch.rand(2, *in_size).type(dtype) for in_size in input_size]
+         summary = OrderedDict()
+         hooks = []
+         model.apply(register_hook)
+         model(*x)
+         for h in hooks:
+             h.remove()
+
+         print_fn("----------------------------------------------------------------")
+         line_new = "{:>20} {:>25} {:>15}".format("Layer (type)", "Output Shape", "Param #")
+         print_fn(line_new)
+         print_fn("================================================================")
+         total_params = 0
+         total_output = 0
+         trainable_params = 0
+         for layer in summary:
+             # input_shape, output_shape, trainable, nb_params
+             line_new = "{:>20} {:>25} {:>15}".format(
+                 layer,
+                 str(summary[layer]["output_shape"]),
+                 "{0:,}".format(summary[layer]["nb_params"]),
+             )
+             total_params += summary[layer]["nb_params"]
+             total_output += np.prod(summary[layer]["output_shape"])
+             if "trainable" in summary[layer]:
+                 if summary[layer]["trainable"]:
+                     trainable_params += summary[layer]["nb_params"]
+             print_fn(line_new)
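
A minimal forward-pass sketch for rawnet.py's RawNet, not part of the commit. The 'filts', 'first_conv' and 'in_channels' values mirror the ones hard-coded in models/image.py; gru_node, nb_gru_layer, nb_fc_node and nb_classes are assumptions borrowed from the usual RawNet2 anti-spoofing recipe and may differ from this project's actual config:

    import torch
    from models.rawnet import RawNet

    d_args = {"filts": [20, [20, 20], [20, 128], [128, 128]],  # as hard-coded in models/image.py
              "first_conv": 1024, "in_channels": 1,            # as hard-coded in models/image.py
              "gru_node": 1024, "nb_gru_layer": 3,             # assumed RawNet2 defaults
              "nb_fc_node": 1024, "nb_classes": 2}             # assumed RawNet2 defaults
    model = RawNet(d_args, device="cpu").eval()
    wav = torch.randn(1, 64600)        # roughly 4 s of 16 kHz audio
    with torch.no_grad():
        out = model(wav)               # log-probabilities; forward also prints the output shape
    assert out.shape == (1, 2)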