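# NOTE: header text captured from the hosting web page along with this file
# (SivaResearch · "demo" · revision b6d5990 · 7.87 kB). The page's UI labels
# (raw / history / blame / contribute / delete / "No virus") were stray text;
# kept here only as a comment so the module parses.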
import re
import os
import wget
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from models.rawnet import SincConv, Residual_block
from models.classifiers import DeepFakeClassifier
class ImageEncoder(nn.Module):
    """DeepFakeClassifier (tf_efficientnet_b7_ns) wrapped with a sigmoid output.

    Args:
        args: namespace-like object providing ``device``,
            ``pretrained_image_encoder`` and ``freeze_image_encoder``.

    When ``args.pretrained_image_encoder`` is truthy, weights are loaded from
    ``PRETRAINED_CKPT`` with strict key matching. When
    ``args.freeze_image_encoder`` is truthy, every parameter of the wrapped
    classifier gets ``requires_grad = False``.
    """

    # Relative checkpoint location. Built with os.path.join so it resolves on
    # POSIX as well as Windows (a "pretrained\\..." literal only works on
    # Windows). Override on the class or instance to load from elsewhere.
    PRETRAINED_CKPT = os.path.join(
        "pretrained", "final_999_DeepFakeClassifier_tf_efficientnet_b7_ns_0_23")

    def __init__(self, args):
        super().__init__()
        self.device = args.device
        self.args = args
        # Not used by forward(); kept so existing references to it keep working.
        self.flatten = nn.Flatten()
        self.sigmoid = nn.Sigmoid()
        self.pretrained_image_encoder = args.pretrained_image_encoder
        self.freeze_image_encoder = args.freeze_image_encoder

        self.model = DeepFakeClassifier(encoder="tf_efficientnet_b7_ns").to(self.device)
        if self.pretrained_image_encoder:
            self._load_pretrained()
        if self.freeze_image_encoder:
            for param in self.model.parameters():
                param.requires_grad = False

    def _load_pretrained(self):
        """Load ``PRETRAINED_CKPT`` into ``self.model`` (strict key matching).

        The checkpoint may be either a raw state dict or a dict holding one
        under ``"state_dict"``. A leading ``"module."`` is stripped from every
        key (presumably left by saving from an ``nn.DataParallel`` wrapper).
        """
        # Locals, not attributes: assigning ``self.state_dict`` would shadow
        # nn.Module.state_dict() and break state_dict() on this module and on
        # any parent module; keeping the checkpoint on ``self`` would also pin
        # it in memory for the module's whole lifetime.
        ckpt = torch.load(self.PRETRAINED_CKPT, map_location=torch.device(self.device))
        state_dict = ckpt.get("state_dict", ckpt)
        print("Loading pretrained image encoder...")
        self.model.load_state_dict(
            {re.sub(r"^module\.", "", k): v for k, v in state_dict.items()},
            strict=True,
        )
        print("Loaded pretrained image encoder.")

    def forward(self, x):
        """Return the sigmoid of the wrapped classifier's output for batch ``x``."""
        return self.sigmoid(self.model(x))
class RawNet(nn.Module):
    """Raw-waveform classifier: SincConv front end, gated residual blocks, GRU.

    Forward pipeline: SincConv -> |x| + max-pool(3) -> BN -> SELU -> six
    residual blocks, each followed by a sigmoid channel gate s applied as
    ``x * s + s`` -> BN -> SELU -> GRU (last time step) -> two linear layers
    -> log-softmax over classes.

    Args:
        args: namespace-like object providing ``device``, ``in_channels``,
            ``gru_node``, ``nb_gru_layer``, ``nb_fc_node``, ``nb_classes``,
            ``pretrained_audio_encoder`` and ``freeze_audio_encoder``.
    """

    # Relative checkpoint location; os.path.join keeps it portable (a
    # "pretrained\\..." literal only resolves on Windows).
    PRETRAINED_CKPT = os.path.join("pretrained", "RawNet.pth")

    def __init__(self, args):
        super().__init__()
        self.device = args.device
        # filts[0]: SincConv output channels; filts[1] / filts[2]: [in, out]
        # channels for the residual blocks. filts[3] is not used in this file.
        self.filts = [20, [20, 20], [20, 128], [128, 128]]
        self.Sinc_conv = SincConv(device=self.device,
                                  out_channels=self.filts[0],
                                  kernel_size=1024,
                                  in_channels=args.in_channels)
        self.first_bn = nn.BatchNorm1d(num_features=self.filts[0])
        self.selu = nn.SELU(inplace=True)
        # Construction order is kept as-is so random initialisation and
        # parameter ordering are unchanged.
        self.block0 = nn.Sequential(Residual_block(nb_filts=self.filts[1], first=True))
        self.block1 = nn.Sequential(Residual_block(nb_filts=self.filts[1]))
        self.block2 = nn.Sequential(Residual_block(nb_filts=self.filts[2]))
        # block2 widens 20 -> 128; the remaining blocks map 128 -> 128.
        self.filts[2][0] = self.filts[2][1]
        self.block3 = nn.Sequential(Residual_block(nb_filts=self.filts[2]))
        self.block4 = nn.Sequential(Residual_block(nb_filts=self.filts[2]))
        self.block5 = nn.Sequential(Residual_block(nb_filts=self.filts[2]))
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        # One square attention FC per residual block, sized to its output width.
        self.fc_attention0 = self._make_attention_fc(in_features=self.filts[1][-1],
                                                     l_out_features=self.filts[1][-1])
        self.fc_attention1 = self._make_attention_fc(in_features=self.filts[1][-1],
                                                     l_out_features=self.filts[1][-1])
        self.fc_attention2 = self._make_attention_fc(in_features=self.filts[2][-1],
                                                     l_out_features=self.filts[2][-1])
        self.fc_attention3 = self._make_attention_fc(in_features=self.filts[2][-1],
                                                     l_out_features=self.filts[2][-1])
        self.fc_attention4 = self._make_attention_fc(in_features=self.filts[2][-1],
                                                     l_out_features=self.filts[2][-1])
        self.fc_attention5 = self._make_attention_fc(in_features=self.filts[2][-1],
                                                     l_out_features=self.filts[2][-1])
        self.bn_before_gru = nn.BatchNorm1d(num_features=self.filts[2][-1])
        self.gru = nn.GRU(input_size=self.filts[2][-1],
                          hidden_size=args.gru_node,
                          num_layers=args.nb_gru_layer,
                          batch_first=True)
        self.fc1_gru = nn.Linear(in_features=args.gru_node,
                                 out_features=args.nb_fc_node)
        self.fc2_gru = nn.Linear(in_features=args.nb_fc_node,
                                 out_features=args.nb_classes, bias=True)
        self.sig = nn.Sigmoid()
        self.logsoftmax = nn.LogSoftmax(dim=1)
        self.pretrained_audio_encoder = args.pretrained_audio_encoder
        self.freeze_audio_encoder = args.freeze_audio_encoder
        if self.pretrained_audio_encoder:
            print("Loading pretrained audio encoder")
            ckpt = torch.load(self.PRETRAINED_CKPT, map_location=torch.device(self.device))
            self.load_state_dict(ckpt, strict=True)
            # Report success only after the weights were actually loaded.
            print("Loaded pretrained audio encoder")
        if self.freeze_audio_encoder:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, x, y=None):
        """Classify a batch of waveforms.

        Args:
            x: input whose first two dims are reshaped to (batch, 1, length);
                presumably (batch, samples) — TODO confirm against callers.
            y: unused; kept for call-site compatibility.

        Returns:
            Log-probabilities of shape (batch, nb_classes).
        """
        nb_samp, len_seq = x.shape[0], x.shape[1]
        x = x.view(nb_samp, 1, len_seq)
        x = self.Sinc_conv(x)
        x = F.max_pool1d(torch.abs(x), 3)
        x = self.selu(self.first_bn(x))

        stages = (
            (self.block0, self.fc_attention0),
            (self.block1, self.fc_attention1),
            (self.block2, self.fc_attention2),
            (self.block3, self.fc_attention3),
            (self.block4, self.fc_attention4),
            (self.block5, self.fc_attention5),
        )
        for block, fc_attention in stages:
            x = self._gated_block(x, block, fc_attention)

        x = self.selu(self.bn_before_gru(x))
        x = x.permute(0, 2, 1)  # (batch, filt, time) -> (batch, time, filt)
        self.gru.flatten_parameters()
        x, _ = self.gru(x)
        x = x[:, -1, :]  # keep only the last time step
        x = self.fc2_gru(self.fc1_gru(x))
        return self.logsoftmax(x)

    def _gated_block(self, x, block, fc_attention):
        """Run ``block`` on ``x`` and rescale its channels with a sigmoid gate.

        The gate s, shaped (batch, filter, 1), is computed from the
        time-averaged block output through ``fc_attention``. The result is
        ``out * s + s``: the gate is used both multiplicatively and additively.
        """
        out = block(x)
        s = self.avgpool(out).view(out.size(0), -1)  # (batch, filter)
        s = fc_attention(s)
        s = self.sig(s).view(s.size(0), s.size(1), -1)  # (batch, filter, 1)
        return out * s + s

    def _make_attention_fc(self, in_features, l_out_features):
        """Return an ``nn.Sequential`` holding one Linear(in_features, l_out_features)."""
        return nn.Sequential(nn.Linear(in_features=in_features,
                                       out_features=l_out_features))

    def _make_layer(self, nb_blocks, nb_filts, first=False):
        """Stack ``nb_blocks`` Residual_blocks into an ``nn.Sequential``.

        The first block uses ``nb_filts`` ([in, out]) and ``first``; every later
        block uses [out, out] and ``first=False``. The caller's ``nb_filts``
        list is not modified (earlier code overwrote ``nb_filts[0]`` in place).
        """
        layers = []
        for i in range(nb_blocks):
            if i == 0:
                block_filts = list(nb_filts)
                block_first = first
            else:
                block_filts = [nb_filts[1], nb_filts[1]]
                block_first = False
            layers.append(Residual_block(nb_filts=block_filts, first=block_first))
        return nn.Sequential(*layers)