Spaces:
Runtime error
Runtime error
Upload 6 files
Browse files- models/classifiers.py +172 -0
- models/efficientnet.onnx +3 -0
- models/image.py +195 -0
- models/links.txt +1 -0
- models/model.pth +3 -0
- models/rawnet.py +360 -0
models/classifiers.py
ADDED
@@ -0,0 +1,172 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from functools import partial
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
from timm.models.efficientnet import tf_efficientnet_b4_ns, tf_efficientnet_b3_ns, \
|
6 |
+
tf_efficientnet_b5_ns, tf_efficientnet_b2_ns, tf_efficientnet_b6_ns, tf_efficientnet_b7_ns
|
7 |
+
from torch import nn
|
8 |
+
from torch.nn.modules.dropout import Dropout
|
9 |
+
from torch.nn.modules.linear import Linear
|
10 |
+
from torch.nn.modules.pooling import AdaptiveAvgPool2d
|
11 |
+
|
12 |
+
encoder_params = {
|
13 |
+
"tf_efficientnet_b3_ns": {
|
14 |
+
"features": 1536,
|
15 |
+
"init_op": partial(tf_efficientnet_b3_ns, pretrained=True, drop_path_rate=0.2)
|
16 |
+
},
|
17 |
+
"tf_efficientnet_b2_ns": {
|
18 |
+
"features": 1408,
|
19 |
+
"init_op": partial(tf_efficientnet_b2_ns, pretrained=False, drop_path_rate=0.2)
|
20 |
+
},
|
21 |
+
"tf_efficientnet_b4_ns": {
|
22 |
+
"features": 1792,
|
23 |
+
"init_op": partial(tf_efficientnet_b4_ns, pretrained=True, drop_path_rate=0.5)
|
24 |
+
},
|
25 |
+
"tf_efficientnet_b5_ns": {
|
26 |
+
"features": 2048,
|
27 |
+
"init_op": partial(tf_efficientnet_b5_ns, pretrained=True, drop_path_rate=0.2)
|
28 |
+
},
|
29 |
+
"tf_efficientnet_b4_ns_03d": {
|
30 |
+
"features": 1792,
|
31 |
+
"init_op": partial(tf_efficientnet_b4_ns, pretrained=True, drop_path_rate=0.3)
|
32 |
+
},
|
33 |
+
"tf_efficientnet_b5_ns_03d": {
|
34 |
+
"features": 2048,
|
35 |
+
"init_op": partial(tf_efficientnet_b5_ns, pretrained=True, drop_path_rate=0.3)
|
36 |
+
},
|
37 |
+
"tf_efficientnet_b5_ns_04d": {
|
38 |
+
"features": 2048,
|
39 |
+
"init_op": partial(tf_efficientnet_b5_ns, pretrained=True, drop_path_rate=0.4)
|
40 |
+
},
|
41 |
+
"tf_efficientnet_b6_ns": {
|
42 |
+
"features": 2304,
|
43 |
+
"init_op": partial(tf_efficientnet_b6_ns, pretrained=True, drop_path_rate=0.2)
|
44 |
+
},
|
45 |
+
"tf_efficientnet_b7_ns": {
|
46 |
+
"features": 2560,
|
47 |
+
"init_op": partial(tf_efficientnet_b7_ns, pretrained=False, drop_path_rate=0.2)
|
48 |
+
},
|
49 |
+
"tf_efficientnet_b6_ns_04d": {
|
50 |
+
"features": 2304,
|
51 |
+
"init_op": partial(tf_efficientnet_b6_ns, pretrained=True, drop_path_rate=0.4)
|
52 |
+
},
|
53 |
+
}
|
54 |
+
|
55 |
+
|
56 |
+
def setup_srm_weights(input_channels: int = 3) -> torch.Tensor:
|
57 |
+
"""Creates the SRM kernels for noise analysis."""
|
58 |
+
# note: values taken from Zhou et al., "Learning Rich Features for Image Manipulation Detection", CVPR2018
|
59 |
+
srm_kernel = torch.from_numpy(np.array([
|
60 |
+
[ # srm 1/2 horiz
|
61 |
+
[0., 0., 0., 0., 0.], # noqa: E241,E201
|
62 |
+
[0., 0., 0., 0., 0.], # noqa: E241,E201
|
63 |
+
[0., 1., -2., 1., 0.], # noqa: E241,E201
|
64 |
+
[0., 0., 0., 0., 0.], # noqa: E241,E201
|
65 |
+
[0., 0., 0., 0., 0.], # noqa: E241,E201
|
66 |
+
], [ # srm 1/4
|
67 |
+
[0., 0., 0., 0., 0.], # noqa: E241,E201
|
68 |
+
[0., -1., 2., -1., 0.], # noqa: E241,E201
|
69 |
+
[0., 2., -4., 2., 0.], # noqa: E241,E201
|
70 |
+
[0., -1., 2., -1., 0.], # noqa: E241,E201
|
71 |
+
[0., 0., 0., 0., 0.], # noqa: E241,E201
|
72 |
+
], [ # srm 1/12
|
73 |
+
[-1., 2., -2., 2., -1.], # noqa: E241,E201
|
74 |
+
[2., -6., 8., -6., 2.], # noqa: E241,E201
|
75 |
+
[-2., 8., -12., 8., -2.], # noqa: E241,E201
|
76 |
+
[2., -6., 8., -6., 2.], # noqa: E241,E201
|
77 |
+
[-1., 2., -2., 2., -1.], # noqa: E241,E201
|
78 |
+
]
|
79 |
+
])).float()
|
80 |
+
srm_kernel[0] /= 2
|
81 |
+
srm_kernel[1] /= 4
|
82 |
+
srm_kernel[2] /= 12
|
83 |
+
return srm_kernel.view(3, 1, 5, 5).repeat(1, input_channels, 1, 1)
|
84 |
+
|
85 |
+
|
86 |
+
def setup_srm_layer(input_channels: int = 3) -> torch.nn.Module:
|
87 |
+
"""Creates a SRM convolution layer for noise analysis."""
|
88 |
+
weights = setup_srm_weights(input_channels)
|
89 |
+
conv = torch.nn.Conv2d(input_channels, out_channels=3, kernel_size=5, stride=1, padding=2, bias=False)
|
90 |
+
with torch.no_grad():
|
91 |
+
conv.weight = torch.nn.Parameter(weights, requires_grad=False)
|
92 |
+
return conv
|
93 |
+
|
94 |
+
|
95 |
+
class DeepFakeClassifierSRM(nn.Module):
|
96 |
+
def __init__(self, encoder, dropout_rate=0.5) -> None:
|
97 |
+
super().__init__()
|
98 |
+
self.encoder = encoder_params[encoder]["init_op"]()
|
99 |
+
self.avg_pool = AdaptiveAvgPool2d((1, 1))
|
100 |
+
self.srm_conv = setup_srm_layer(3)
|
101 |
+
self.dropout = Dropout(dropout_rate)
|
102 |
+
self.fc = Linear(encoder_params[encoder]["features"], 1)
|
103 |
+
|
104 |
+
def forward(self, x):
|
105 |
+
noise = self.srm_conv(x)
|
106 |
+
x = self.encoder.forward_features(noise)
|
107 |
+
x = self.avg_pool(x).flatten(1)
|
108 |
+
x = self.dropout(x)
|
109 |
+
x = self.fc(x)
|
110 |
+
return x
|
111 |
+
|
112 |
+
|
113 |
+
class GlobalWeightedAvgPool2d(nn.Module):
|
114 |
+
"""
|
115 |
+
Global Weighted Average Pooling from paper "Global Weighted Average
|
116 |
+
Pooling Bridges Pixel-level Localization and Image-level Classification"
|
117 |
+
"""
|
118 |
+
|
119 |
+
def __init__(self, features: int, flatten=False):
|
120 |
+
super().__init__()
|
121 |
+
self.conv = nn.Conv2d(features, 1, kernel_size=1, bias=True)
|
122 |
+
self.flatten = flatten
|
123 |
+
|
124 |
+
def fscore(self, x):
|
125 |
+
m = self.conv(x)
|
126 |
+
m = m.sigmoid().exp()
|
127 |
+
return m
|
128 |
+
|
129 |
+
def norm(self, x: torch.Tensor):
|
130 |
+
return x / x.sum(dim=[2, 3], keepdim=True)
|
131 |
+
|
132 |
+
def forward(self, x):
|
133 |
+
input_x = x
|
134 |
+
x = self.fscore(x)
|
135 |
+
x = self.norm(x)
|
136 |
+
x = x * input_x
|
137 |
+
x = x.sum(dim=[2, 3], keepdim=not self.flatten)
|
138 |
+
return x
|
139 |
+
|
140 |
+
|
141 |
+
class DeepFakeClassifier(nn.Module):
|
142 |
+
def __init__(self, encoder, dropout_rate=0.0) -> None:
|
143 |
+
super().__init__()
|
144 |
+
self.encoder = encoder_params[encoder]["init_op"]()
|
145 |
+
self.avg_pool = AdaptiveAvgPool2d((1, 1))
|
146 |
+
self.dropout = Dropout(dropout_rate)
|
147 |
+
self.fc = Linear(encoder_params[encoder]["features"], 1)
|
148 |
+
|
149 |
+
def forward(self, x):
|
150 |
+
x = self.encoder.forward_features(x)
|
151 |
+
x = self.avg_pool(x).flatten(1)
|
152 |
+
x = self.dropout(x)
|
153 |
+
x = self.fc(x)
|
154 |
+
return x
|
155 |
+
|
156 |
+
|
157 |
+
|
158 |
+
|
159 |
+
class DeepFakeClassifierGWAP(nn.Module):
|
160 |
+
def __init__(self, encoder, dropout_rate=0.5) -> None:
|
161 |
+
super().__init__()
|
162 |
+
self.encoder = encoder_params[encoder]["init_op"]()
|
163 |
+
self.avg_pool = GlobalWeightedAvgPool2d(encoder_params[encoder]["features"])
|
164 |
+
self.dropout = Dropout(dropout_rate)
|
165 |
+
self.fc = Linear(encoder_params[encoder]["features"], 1)
|
166 |
+
|
167 |
+
def forward(self, x):
|
168 |
+
x = self.encoder.forward_features(x)
|
169 |
+
x = self.avg_pool(x).flatten(1)
|
170 |
+
x = self.dropout(x)
|
171 |
+
x = self.fc(x)
|
172 |
+
return x
|
models/efficientnet.onnx
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:39774e1cc878ac2b587fd4dc1c96fba084c9fe5ee3106a43b560f6054a69ba26
|
3 |
+
size 133
|
models/image.py
ADDED
@@ -0,0 +1,195 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import re
|
2 |
+
import os
|
3 |
+
import wget
|
4 |
+
import torch
|
5 |
+
import torchvision
|
6 |
+
import torch.nn as nn
|
7 |
+
import torch.nn.functional as F
|
8 |
+
from models.rawnet import SincConv, Residual_block
|
9 |
+
from models.classifiers import DeepFakeClassifier
|
10 |
+
|
11 |
+
class ImageEncoder(nn.Module):
|
12 |
+
def __init__(self, args):
|
13 |
+
super(ImageEncoder, self).__init__()
|
14 |
+
self.device = args.device
|
15 |
+
self.args = args
|
16 |
+
self.flatten = nn.Flatten()
|
17 |
+
self.sigmoid = nn.Sigmoid()
|
18 |
+
# self.fc = nn.Linear(in_features=2560, out_features = 2)
|
19 |
+
self.pretrained_image_encoder = args.pretrained_image_encoder
|
20 |
+
self.freeze_image_encoder = args.freeze_image_encoder
|
21 |
+
|
22 |
+
if self.pretrained_image_encoder == False:
|
23 |
+
self.model = DeepFakeClassifier(encoder = "tf_efficientnet_b7_ns").to(self.device)
|
24 |
+
|
25 |
+
else:
|
26 |
+
self.pretrained_ckpt = torch.load('pretrained\\final_999_DeepFakeClassifier_tf_efficientnet_b7_ns_0_23', map_location = torch.device(self.args.device))
|
27 |
+
self.state_dict = self.pretrained_ckpt.get("state_dict", self.pretrained_ckpt)
|
28 |
+
|
29 |
+
self.model = DeepFakeClassifier(encoder = "tf_efficientnet_b7_ns").to(self.device)
|
30 |
+
print("Loading pretrained image encoder...")
|
31 |
+
self.model.load_state_dict({re.sub("^module.", "", k): v for k, v in self.state_dict.items()}, strict=True)
|
32 |
+
print("Loaded pretrained image encoder.")
|
33 |
+
|
34 |
+
if self.freeze_image_encoder == True:
|
35 |
+
for idx, param in self.model.named_parameters():
|
36 |
+
param.requires_grad = False
|
37 |
+
|
38 |
+
# self.model.fc = nn.Identity()
|
39 |
+
|
40 |
+
def forward(self, x):
|
41 |
+
x = self.model(x)
|
42 |
+
out = self.sigmoid(x)
|
43 |
+
# x = self.flatten(x)
|
44 |
+
# out = self.fc(x)
|
45 |
+
return out
|
46 |
+
|
47 |
+
|
48 |
+
class RawNet(nn.Module):
|
49 |
+
def __init__(self, args):
|
50 |
+
super(RawNet, self).__init__()
|
51 |
+
|
52 |
+
self.device=args.device
|
53 |
+
self.filts = [20, [20, 20], [20, 128], [128, 128]]
|
54 |
+
|
55 |
+
self.Sinc_conv=SincConv(device=self.device,
|
56 |
+
out_channels = self.filts[0],
|
57 |
+
kernel_size = 1024,
|
58 |
+
in_channels = args.in_channels)
|
59 |
+
|
60 |
+
self.first_bn = nn.BatchNorm1d(num_features = self.filts[0])
|
61 |
+
self.selu = nn.SELU(inplace=True)
|
62 |
+
self.block0 = nn.Sequential(Residual_block(nb_filts = self.filts[1], first = True))
|
63 |
+
self.block1 = nn.Sequential(Residual_block(nb_filts = self.filts[1]))
|
64 |
+
self.block2 = nn.Sequential(Residual_block(nb_filts = self.filts[2]))
|
65 |
+
self.filts[2][0] = self.filts[2][1]
|
66 |
+
self.block3 = nn.Sequential(Residual_block(nb_filts = self.filts[2]))
|
67 |
+
self.block4 = nn.Sequential(Residual_block(nb_filts = self.filts[2]))
|
68 |
+
self.block5 = nn.Sequential(Residual_block(nb_filts = self.filts[2]))
|
69 |
+
self.avgpool = nn.AdaptiveAvgPool1d(1)
|
70 |
+
|
71 |
+
self.fc_attention0 = self._make_attention_fc(in_features = self.filts[1][-1],
|
72 |
+
l_out_features = self.filts[1][-1])
|
73 |
+
self.fc_attention1 = self._make_attention_fc(in_features = self.filts[1][-1],
|
74 |
+
l_out_features = self.filts[1][-1])
|
75 |
+
self.fc_attention2 = self._make_attention_fc(in_features = self.filts[2][-1],
|
76 |
+
l_out_features = self.filts[2][-1])
|
77 |
+
self.fc_attention3 = self._make_attention_fc(in_features = self.filts[2][-1],
|
78 |
+
l_out_features = self.filts[2][-1])
|
79 |
+
self.fc_attention4 = self._make_attention_fc(in_features = self.filts[2][-1],
|
80 |
+
l_out_features = self.filts[2][-1])
|
81 |
+
self.fc_attention5 = self._make_attention_fc(in_features = self.filts[2][-1],
|
82 |
+
l_out_features = self.filts[2][-1])
|
83 |
+
|
84 |
+
self.bn_before_gru = nn.BatchNorm1d(num_features = self.filts[2][-1])
|
85 |
+
self.gru = nn.GRU(input_size = self.filts[2][-1],
|
86 |
+
hidden_size = args.gru_node,
|
87 |
+
num_layers = args.nb_gru_layer,
|
88 |
+
batch_first = True)
|
89 |
+
|
90 |
+
self.fc1_gru = nn.Linear(in_features = args.gru_node,
|
91 |
+
out_features = args.nb_fc_node)
|
92 |
+
|
93 |
+
self.fc2_gru = nn.Linear(in_features = args.nb_fc_node,
|
94 |
+
out_features = args.nb_classes ,bias=True)
|
95 |
+
|
96 |
+
self.sig = nn.Sigmoid()
|
97 |
+
self.logsoftmax = nn.LogSoftmax(dim=1)
|
98 |
+
self.pretrained_audio_encoder = args.pretrained_audio_encoder
|
99 |
+
self.freeze_audio_encoder = args.freeze_audio_encoder
|
100 |
+
|
101 |
+
if self.pretrained_audio_encoder == True:
|
102 |
+
print("Loading pretrained audio encoder")
|
103 |
+
ckpt = torch.load('pretrained\\RawNet.pth', map_location = torch.device(self.device))
|
104 |
+
print("Loaded pretrained audio encoder")
|
105 |
+
self.load_state_dict(ckpt, strict = True)
|
106 |
+
|
107 |
+
if self.freeze_audio_encoder:
|
108 |
+
for param in self.parameters():
|
109 |
+
param.requires_grad = False
|
110 |
+
|
111 |
+
|
112 |
+
def forward(self, x, y = None):
|
113 |
+
|
114 |
+
nb_samp = x.shape[0]
|
115 |
+
len_seq = x.shape[1]
|
116 |
+
x=x.view(nb_samp,1,len_seq)
|
117 |
+
|
118 |
+
x = self.Sinc_conv(x)
|
119 |
+
x = F.max_pool1d(torch.abs(x), 3)
|
120 |
+
x = self.first_bn(x)
|
121 |
+
x = self.selu(x)
|
122 |
+
|
123 |
+
x0 = self.block0(x)
|
124 |
+
y0 = self.avgpool(x0).view(x0.size(0), -1) # torch.Size([batch, filter])
|
125 |
+
y0 = self.fc_attention0(y0)
|
126 |
+
y0 = self.sig(y0).view(y0.size(0), y0.size(1), -1) # torch.Size([batch, filter, 1])
|
127 |
+
x = x0 * y0 + y0 # (batch, filter, time) x (batch, filter, 1)
|
128 |
+
|
129 |
+
|
130 |
+
x1 = self.block1(x)
|
131 |
+
y1 = self.avgpool(x1).view(x1.size(0), -1) # torch.Size([batch, filter])
|
132 |
+
y1 = self.fc_attention1(y1)
|
133 |
+
y1 = self.sig(y1).view(y1.size(0), y1.size(1), -1) # torch.Size([batch, filter, 1])
|
134 |
+
x = x1 * y1 + y1 # (batch, filter, time) x (batch, filter, 1)
|
135 |
+
|
136 |
+
x2 = self.block2(x)
|
137 |
+
y2 = self.avgpool(x2).view(x2.size(0), -1) # torch.Size([batch, filter])
|
138 |
+
y2 = self.fc_attention2(y2)
|
139 |
+
y2 = self.sig(y2).view(y2.size(0), y2.size(1), -1) # torch.Size([batch, filter, 1])
|
140 |
+
x = x2 * y2 + y2 # (batch, filter, time) x (batch, filter, 1)
|
141 |
+
|
142 |
+
x3 = self.block3(x)
|
143 |
+
y3 = self.avgpool(x3).view(x3.size(0), -1) # torch.Size([batch, filter])
|
144 |
+
y3 = self.fc_attention3(y3)
|
145 |
+
y3 = self.sig(y3).view(y3.size(0), y3.size(1), -1) # torch.Size([batch, filter, 1])
|
146 |
+
x = x3 * y3 + y3 # (batch, filter, time) x (batch, filter, 1)
|
147 |
+
|
148 |
+
x4 = self.block4(x)
|
149 |
+
y4 = self.avgpool(x4).view(x4.size(0), -1) # torch.Size([batch, filter])
|
150 |
+
y4 = self.fc_attention4(y4)
|
151 |
+
y4 = self.sig(y4).view(y4.size(0), y4.size(1), -1) # torch.Size([batch, filter, 1])
|
152 |
+
x = x4 * y4 + y4 # (batch, filter, time) x (batch, filter, 1)
|
153 |
+
|
154 |
+
x5 = self.block5(x)
|
155 |
+
y5 = self.avgpool(x5).view(x5.size(0), -1) # torch.Size([batch, filter])
|
156 |
+
y5 = self.fc_attention5(y5)
|
157 |
+
y5 = self.sig(y5).view(y5.size(0), y5.size(1), -1) # torch.Size([batch, filter, 1])
|
158 |
+
x = x5 * y5 + y5 # (batch, filter, time) x (batch, filter, 1)
|
159 |
+
|
160 |
+
x = self.bn_before_gru(x)
|
161 |
+
x = self.selu(x)
|
162 |
+
x = x.permute(0, 2, 1) #(batch, filt, time) >> (batch, time, filt)
|
163 |
+
self.gru.flatten_parameters()
|
164 |
+
x, _ = self.gru(x)
|
165 |
+
x = x[:,-1,:]
|
166 |
+
x = self.fc1_gru(x)
|
167 |
+
x = self.fc2_gru(x)
|
168 |
+
output=self.logsoftmax(x)
|
169 |
+
|
170 |
+
return output
|
171 |
+
|
172 |
+
|
173 |
+
|
174 |
+
def _make_attention_fc(self, in_features, l_out_features):
|
175 |
+
|
176 |
+
l_fc = []
|
177 |
+
|
178 |
+
l_fc.append(nn.Linear(in_features = in_features,
|
179 |
+
out_features = l_out_features))
|
180 |
+
|
181 |
+
|
182 |
+
|
183 |
+
return nn.Sequential(*l_fc)
|
184 |
+
|
185 |
+
|
186 |
+
def _make_layer(self, nb_blocks, nb_filts, first = False):
|
187 |
+
layers = []
|
188 |
+
#def __init__(self, nb_filts, first = False):
|
189 |
+
for i in range(nb_blocks):
|
190 |
+
first = first if i == 0 else False
|
191 |
+
layers.append(Residual_block(nb_filts = nb_filts,
|
192 |
+
first = first))
|
193 |
+
if i == 0: nb_filts[0] = nb_filts[1]
|
194 |
+
|
195 |
+
return nn.Sequential(*layers)
|
models/links.txt
CHANGED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
|
models/model.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:0315e9ad76374c2e0f91249847d4b1c8ad8c2b20ac334836e8e79657daa4b63a
|
3 |
+
size 134
|
models/rawnet.py
ADDED
@@ -0,0 +1,360 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from torch import Tensor
|
5 |
+
import numpy as np
|
6 |
+
from torch.utils import data
|
7 |
+
from collections import OrderedDict
|
8 |
+
from torch.nn.parameter import Parameter
|
9 |
+
|
10 |
+
|
11 |
+
class SincConv(nn.Module):
|
12 |
+
@staticmethod
|
13 |
+
def to_mel(hz):
|
14 |
+
return 2595 * np.log10(1 + hz / 700)
|
15 |
+
|
16 |
+
@staticmethod
|
17 |
+
def to_hz(mel):
|
18 |
+
return 700 * (10 ** (mel / 2595) - 1)
|
19 |
+
|
20 |
+
|
21 |
+
def __init__(self, device,out_channels, kernel_size,in_channels=1,sample_rate=16000,
|
22 |
+
stride=1, padding=0, dilation=1, bias=False, groups=1):
|
23 |
+
|
24 |
+
super(SincConv,self).__init__()
|
25 |
+
|
26 |
+
if in_channels != 1:
|
27 |
+
|
28 |
+
msg = "SincConv only support one input channel (here, in_channels = {%i})" % (in_channels)
|
29 |
+
raise ValueError(msg)
|
30 |
+
|
31 |
+
self.out_channels = out_channels
|
32 |
+
self.kernel_size = kernel_size
|
33 |
+
self.sample_rate=sample_rate
|
34 |
+
|
35 |
+
# Forcing the filters to be odd (i.e, perfectly symmetrics)
|
36 |
+
if kernel_size%2==0:
|
37 |
+
self.kernel_size=self.kernel_size+1
|
38 |
+
|
39 |
+
self.device=device
|
40 |
+
self.stride = stride
|
41 |
+
self.padding = padding
|
42 |
+
self.dilation = dilation
|
43 |
+
|
44 |
+
if bias:
|
45 |
+
raise ValueError('SincConv does not support bias.')
|
46 |
+
if groups > 1:
|
47 |
+
raise ValueError('SincConv does not support groups.')
|
48 |
+
|
49 |
+
|
50 |
+
# initialize filterbanks using Mel scale
|
51 |
+
NFFT = 512
|
52 |
+
f=int(self.sample_rate/2)*np.linspace(0,1,int(NFFT/2)+1)
|
53 |
+
fmel=self.to_mel(f) # Hz to mel conversion
|
54 |
+
fmelmax=np.max(fmel)
|
55 |
+
fmelmin=np.min(fmel)
|
56 |
+
filbandwidthsmel=np.linspace(fmelmin,fmelmax,self.out_channels+1)
|
57 |
+
filbandwidthsf=self.to_hz(filbandwidthsmel) # Mel to Hz conversion
|
58 |
+
self.mel=filbandwidthsf
|
59 |
+
self.hsupp=torch.arange(-(self.kernel_size-1)/2, (self.kernel_size-1)/2+1)
|
60 |
+
self.band_pass=torch.zeros(self.out_channels,self.kernel_size)
|
61 |
+
|
62 |
+
|
63 |
+
|
64 |
+
def forward(self,x):
|
65 |
+
for i in range(len(self.mel)-1):
|
66 |
+
fmin=self.mel[i]
|
67 |
+
fmax=self.mel[i+1]
|
68 |
+
hHigh=(2*fmax/self.sample_rate)*np.sinc(2*fmax*self.hsupp/self.sample_rate)
|
69 |
+
hLow=(2*fmin/self.sample_rate)*np.sinc(2*fmin*self.hsupp/self.sample_rate)
|
70 |
+
hideal=hHigh-hLow
|
71 |
+
|
72 |
+
self.band_pass[i,:]=Tensor(np.hamming(self.kernel_size))*Tensor(hideal)
|
73 |
+
|
74 |
+
band_pass_filter=self.band_pass.to(self.device)
|
75 |
+
|
76 |
+
self.filters = (band_pass_filter).view(self.out_channels, 1, self.kernel_size)
|
77 |
+
|
78 |
+
return F.conv1d(x, self.filters, stride=self.stride,
|
79 |
+
padding=self.padding, dilation=self.dilation,
|
80 |
+
bias=None, groups=1)
|
81 |
+
|
82 |
+
|
83 |
+
|
84 |
+
class Residual_block(nn.Module):
|
85 |
+
def __init__(self, nb_filts, first = False):
|
86 |
+
super(Residual_block, self).__init__()
|
87 |
+
self.first = first
|
88 |
+
|
89 |
+
if not self.first:
|
90 |
+
self.bn1 = nn.BatchNorm1d(num_features = nb_filts[0])
|
91 |
+
|
92 |
+
self.lrelu = nn.LeakyReLU(negative_slope=0.3)
|
93 |
+
|
94 |
+
self.conv1 = nn.Conv1d(in_channels = nb_filts[0],
|
95 |
+
out_channels = nb_filts[1],
|
96 |
+
kernel_size = 3,
|
97 |
+
padding = 1,
|
98 |
+
stride = 1)
|
99 |
+
|
100 |
+
self.bn2 = nn.BatchNorm1d(num_features = nb_filts[1])
|
101 |
+
self.conv2 = nn.Conv1d(in_channels = nb_filts[1],
|
102 |
+
out_channels = nb_filts[1],
|
103 |
+
padding = 1,
|
104 |
+
kernel_size = 3,
|
105 |
+
stride = 1)
|
106 |
+
|
107 |
+
if nb_filts[0] != nb_filts[1]:
|
108 |
+
self.downsample = True
|
109 |
+
self.conv_downsample = nn.Conv1d(in_channels = nb_filts[0],
|
110 |
+
out_channels = nb_filts[1],
|
111 |
+
padding = 0,
|
112 |
+
kernel_size = 1,
|
113 |
+
stride = 1)
|
114 |
+
|
115 |
+
else:
|
116 |
+
self.downsample = False
|
117 |
+
self.mp = nn.MaxPool1d(3)
|
118 |
+
|
119 |
+
def forward(self, x):
|
120 |
+
identity = x
|
121 |
+
if not self.first:
|
122 |
+
out = self.bn1(x)
|
123 |
+
out = self.lrelu(out)
|
124 |
+
else:
|
125 |
+
out = x
|
126 |
+
|
127 |
+
out = self.conv1(x)
|
128 |
+
out = self.bn2(out)
|
129 |
+
out = self.lrelu(out)
|
130 |
+
out = self.conv2(out)
|
131 |
+
|
132 |
+
if self.downsample:
|
133 |
+
identity = self.conv_downsample(identity)
|
134 |
+
|
135 |
+
out += identity
|
136 |
+
out = self.mp(out)
|
137 |
+
return out
|
138 |
+
|
139 |
+
|
140 |
+
|
141 |
+
|
142 |
+
|
143 |
+
class RawNet(nn.Module):
|
144 |
+
def __init__(self, d_args, device):
|
145 |
+
super(RawNet, self).__init__()
|
146 |
+
|
147 |
+
|
148 |
+
self.device=device
|
149 |
+
|
150 |
+
self.Sinc_conv=SincConv(device=self.device,
|
151 |
+
out_channels = d_args['filts'][0],
|
152 |
+
kernel_size = d_args['first_conv'],
|
153 |
+
in_channels = d_args['in_channels']
|
154 |
+
)
|
155 |
+
|
156 |
+
self.first_bn = nn.BatchNorm1d(num_features = d_args['filts'][0])
|
157 |
+
self.selu = nn.SELU(inplace=True)
|
158 |
+
self.block0 = nn.Sequential(Residual_block(nb_filts = d_args['filts'][1], first = True))
|
159 |
+
self.block1 = nn.Sequential(Residual_block(nb_filts = d_args['filts'][1]))
|
160 |
+
self.block2 = nn.Sequential(Residual_block(nb_filts = d_args['filts'][2]))
|
161 |
+
d_args['filts'][2][0] = d_args['filts'][2][1]
|
162 |
+
self.block3 = nn.Sequential(Residual_block(nb_filts = d_args['filts'][2]))
|
163 |
+
self.block4 = nn.Sequential(Residual_block(nb_filts = d_args['filts'][2]))
|
164 |
+
self.block5 = nn.Sequential(Residual_block(nb_filts = d_args['filts'][2]))
|
165 |
+
self.avgpool = nn.AdaptiveAvgPool1d(1)
|
166 |
+
|
167 |
+
self.fc_attention0 = self._make_attention_fc(in_features = d_args['filts'][1][-1],
|
168 |
+
l_out_features = d_args['filts'][1][-1])
|
169 |
+
self.fc_attention1 = self._make_attention_fc(in_features = d_args['filts'][1][-1],
|
170 |
+
l_out_features = d_args['filts'][1][-1])
|
171 |
+
self.fc_attention2 = self._make_attention_fc(in_features = d_args['filts'][2][-1],
|
172 |
+
l_out_features = d_args['filts'][2][-1])
|
173 |
+
self.fc_attention3 = self._make_attention_fc(in_features = d_args['filts'][2][-1],
|
174 |
+
l_out_features = d_args['filts'][2][-1])
|
175 |
+
self.fc_attention4 = self._make_attention_fc(in_features = d_args['filts'][2][-1],
|
176 |
+
l_out_features = d_args['filts'][2][-1])
|
177 |
+
self.fc_attention5 = self._make_attention_fc(in_features = d_args['filts'][2][-1],
|
178 |
+
l_out_features = d_args['filts'][2][-1])
|
179 |
+
|
180 |
+
self.bn_before_gru = nn.BatchNorm1d(num_features = d_args['filts'][2][-1])
|
181 |
+
self.gru = nn.GRU(input_size = d_args['filts'][2][-1],
|
182 |
+
hidden_size = d_args['gru_node'],
|
183 |
+
num_layers = d_args['nb_gru_layer'],
|
184 |
+
batch_first = True)
|
185 |
+
|
186 |
+
|
187 |
+
self.fc1_gru = nn.Linear(in_features = d_args['gru_node'],
|
188 |
+
out_features = d_args['nb_fc_node'])
|
189 |
+
|
190 |
+
self.fc2_gru = nn.Linear(in_features = d_args['nb_fc_node'],
|
191 |
+
out_features = d_args['nb_classes'],bias=True)
|
192 |
+
|
193 |
+
|
194 |
+
self.sig = nn.Sigmoid()
|
195 |
+
self.logsoftmax = nn.LogSoftmax(dim=1)
|
196 |
+
|
197 |
+
def forward(self, x, y = None):
|
198 |
+
|
199 |
+
|
200 |
+
nb_samp = x.shape[0]
|
201 |
+
len_seq = x.shape[1]
|
202 |
+
x=x.view(nb_samp,1,len_seq)
|
203 |
+
|
204 |
+
x = self.Sinc_conv(x)
|
205 |
+
x = F.max_pool1d(torch.abs(x), 3)
|
206 |
+
x = self.first_bn(x)
|
207 |
+
x = self.selu(x)
|
208 |
+
|
209 |
+
x0 = self.block0(x)
|
210 |
+
y0 = self.avgpool(x0).view(x0.size(0), -1) # torch.Size([batch, filter])
|
211 |
+
y0 = self.fc_attention0(y0)
|
212 |
+
y0 = self.sig(y0).view(y0.size(0), y0.size(1), -1) # torch.Size([batch, filter, 1])
|
213 |
+
x = x0 * y0 + y0 # (batch, filter, time) x (batch, filter, 1)
|
214 |
+
|
215 |
+
|
216 |
+
x1 = self.block1(x)
|
217 |
+
y1 = self.avgpool(x1).view(x1.size(0), -1) # torch.Size([batch, filter])
|
218 |
+
y1 = self.fc_attention1(y1)
|
219 |
+
y1 = self.sig(y1).view(y1.size(0), y1.size(1), -1) # torch.Size([batch, filter, 1])
|
220 |
+
x = x1 * y1 + y1 # (batch, filter, time) x (batch, filter, 1)
|
221 |
+
|
222 |
+
x2 = self.block2(x)
|
223 |
+
y2 = self.avgpool(x2).view(x2.size(0), -1) # torch.Size([batch, filter])
|
224 |
+
y2 = self.fc_attention2(y2)
|
225 |
+
y2 = self.sig(y2).view(y2.size(0), y2.size(1), -1) # torch.Size([batch, filter, 1])
|
226 |
+
x = x2 * y2 + y2 # (batch, filter, time) x (batch, filter, 1)
|
227 |
+
|
228 |
+
x3 = self.block3(x)
|
229 |
+
y3 = self.avgpool(x3).view(x3.size(0), -1) # torch.Size([batch, filter])
|
230 |
+
y3 = self.fc_attention3(y3)
|
231 |
+
y3 = self.sig(y3).view(y3.size(0), y3.size(1), -1) # torch.Size([batch, filter, 1])
|
232 |
+
x = x3 * y3 + y3 # (batch, filter, time) x (batch, filter, 1)
|
233 |
+
|
234 |
+
x4 = self.block4(x)
|
235 |
+
y4 = self.avgpool(x4).view(x4.size(0), -1) # torch.Size([batch, filter])
|
236 |
+
y4 = self.fc_attention4(y4)
|
237 |
+
y4 = self.sig(y4).view(y4.size(0), y4.size(1), -1) # torch.Size([batch, filter, 1])
|
238 |
+
x = x4 * y4 + y4 # (batch, filter, time) x (batch, filter, 1)
|
239 |
+
|
240 |
+
x5 = self.block5(x)
|
241 |
+
y5 = self.avgpool(x5).view(x5.size(0), -1) # torch.Size([batch, filter])
|
242 |
+
y5 = self.fc_attention5(y5)
|
243 |
+
y5 = self.sig(y5).view(y5.size(0), y5.size(1), -1) # torch.Size([batch, filter, 1])
|
244 |
+
x = x5 * y5 + y5 # (batch, filter, time) x (batch, filter, 1)
|
245 |
+
|
246 |
+
x = self.bn_before_gru(x)
|
247 |
+
x = self.selu(x)
|
248 |
+
x = x.permute(0, 2, 1) #(batch, filt, time) >> (batch, time, filt)
|
249 |
+
self.gru.flatten_parameters()
|
250 |
+
x, _ = self.gru(x)
|
251 |
+
x = x[:,-1,:]
|
252 |
+
x = self.fc1_gru(x)
|
253 |
+
x = self.fc2_gru(x)
|
254 |
+
output=self.logsoftmax(x)
|
255 |
+
print(f"Spec output shape: {output.shape}")
|
256 |
+
|
257 |
+
return output
|
258 |
+
|
259 |
+
|
260 |
+
|
261 |
+
def _make_attention_fc(self, in_features, l_out_features):
|
262 |
+
|
263 |
+
l_fc = []
|
264 |
+
|
265 |
+
l_fc.append(nn.Linear(in_features = in_features,
|
266 |
+
out_features = l_out_features))
|
267 |
+
|
268 |
+
|
269 |
+
|
270 |
+
return nn.Sequential(*l_fc)
|
271 |
+
|
272 |
+
|
273 |
+
def _make_layer(self, nb_blocks, nb_filts, first = False):
|
274 |
+
layers = []
|
275 |
+
#def __init__(self, nb_filts, first = False):
|
276 |
+
for i in range(nb_blocks):
|
277 |
+
first = first if i == 0 else False
|
278 |
+
layers.append(Residual_block(nb_filts = nb_filts,
|
279 |
+
first = first))
|
280 |
+
if i == 0: nb_filts[0] = nb_filts[1]
|
281 |
+
|
282 |
+
return nn.Sequential(*layers)
|
283 |
+
|
284 |
+
def summary(self, input_size, batch_size=-1, device="cuda", print_fn = None):
|
285 |
+
if print_fn == None: printfn = print
|
286 |
+
model = self
|
287 |
+
|
288 |
+
def register_hook(module):
|
289 |
+
def hook(module, input, output):
|
290 |
+
class_name = str(module.__class__).split(".")[-1].split("'")[0]
|
291 |
+
module_idx = len(summary)
|
292 |
+
|
293 |
+
m_key = "%s-%i" % (class_name, module_idx + 1)
|
294 |
+
summary[m_key] = OrderedDict()
|
295 |
+
summary[m_key]["input_shape"] = list(input[0].size())
|
296 |
+
summary[m_key]["input_shape"][0] = batch_size
|
297 |
+
if isinstance(output, (list, tuple)):
|
298 |
+
summary[m_key]["output_shape"] = [
|
299 |
+
[-1] + list(o.size())[1:] for o in output
|
300 |
+
]
|
301 |
+
else:
|
302 |
+
summary[m_key]["output_shape"] = list(output.size())
|
303 |
+
if len(summary[m_key]["output_shape"]) != 0:
|
304 |
+
summary[m_key]["output_shape"][0] = batch_size
|
305 |
+
|
306 |
+
params = 0
|
307 |
+
if hasattr(module, "weight") and hasattr(module.weight, "size"):
|
308 |
+
params += torch.prod(torch.LongTensor(list(module.weight.size())))
|
309 |
+
summary[m_key]["trainable"] = module.weight.requires_grad
|
310 |
+
if hasattr(module, "bias") and hasattr(module.bias, "size"):
|
311 |
+
params += torch.prod(torch.LongTensor(list(module.bias.size())))
|
312 |
+
summary[m_key]["nb_params"] = params
|
313 |
+
|
314 |
+
if (
|
315 |
+
not isinstance(module, nn.Sequential)
|
316 |
+
and not isinstance(module, nn.ModuleList)
|
317 |
+
and not (module == model)
|
318 |
+
):
|
319 |
+
hooks.append(module.register_forward_hook(hook))
|
320 |
+
|
321 |
+
device = device.lower()
|
322 |
+
assert device in [
|
323 |
+
"cuda",
|
324 |
+
"cpu",
|
325 |
+
], "Input device is not valid, please specify 'cuda' or 'cpu'"
|
326 |
+
|
327 |
+
if device == "cuda" and torch.cuda.is_available():
|
328 |
+
dtype = torch.cuda.FloatTensor
|
329 |
+
else:
|
330 |
+
dtype = torch.FloatTensor
|
331 |
+
if isinstance(input_size, tuple):
|
332 |
+
input_size = [input_size]
|
333 |
+
x = [torch.rand(2, *in_size).type(dtype) for in_size in input_size]
|
334 |
+
summary = OrderedDict()
|
335 |
+
hooks = []
|
336 |
+
model.apply(register_hook)
|
337 |
+
model(*x)
|
338 |
+
for h in hooks:
|
339 |
+
h.remove()
|
340 |
+
|
341 |
+
print_fn("----------------------------------------------------------------")
|
342 |
+
line_new = "{:>20} {:>25} {:>15}".format("Layer (type)", "Output Shape", "Param #")
|
343 |
+
print_fn(line_new)
|
344 |
+
print_fn("================================================================")
|
345 |
+
total_params = 0
|
346 |
+
total_output = 0
|
347 |
+
trainable_params = 0
|
348 |
+
for layer in summary:
|
349 |
+
# input_shape, output_shape, trainable, nb_params
|
350 |
+
line_new = "{:>20} {:>25} {:>15}".format(
|
351 |
+
layer,
|
352 |
+
str(summary[layer]["output_shape"]),
|
353 |
+
"{0:,}".format(summary[layer]["nb_params"]),
|
354 |
+
)
|
355 |
+
total_params += summary[layer]["nb_params"]
|
356 |
+
total_output += np.prod(summary[layer]["output_shape"])
|
357 |
+
if "trainable" in summary[layer]:
|
358 |
+
if summary[layer]["trainable"] == True:
|
359 |
+
trainable_params += summary[layer]["nb_params"]
|
360 |
+
print_fn(line_new)
|