|
import re |
|
import os |
|
import wget |
|
import torch |
|
import torchvision |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from models.rawnet import SincConv, Residual_block |
|
from models.classifiers import DeepFakeClassifier |
|
|
|
class ImageEncoder(nn.Module):
    """Image branch: a ``DeepFakeClassifier`` whose output is passed through a sigmoid.

    The classifier is always built with ``encoder="tf_efficientnet_b7_ns"``.
    Optionally its weights are restored from a local checkpoint and/or all of
    its parameters are frozen.

    Args:
        args: Configuration object. The attributes read here are ``device``,
            ``pretrained_image_encoder`` and ``freeze_image_encoder``. The two
            flags are interpreted by truthiness.
    """

    # Checkpoint location, relative to the current working directory. Built
    # with os.path.join so the separator is correct on every platform (the
    # previous hard-coded backslash only resolved on Windows).
    _PRETRAINED_CKPT = os.path.join(
        "pretrained",
        "final_999_DeepFakeClassifier_tf_efficientnet_b7_ns_0_23",
    )

    def __init__(self, args):
        super(ImageEncoder, self).__init__()

        self.device = args.device
        self.args = args
        # Not used by forward() in this class; kept so the attribute remains available.
        self.flatten = nn.Flatten()
        self.sigmoid = nn.Sigmoid()

        self.pretrained_image_encoder = args.pretrained_image_encoder
        self.freeze_image_encoder = args.freeze_image_encoder

        self.model = DeepFakeClassifier(encoder="tf_efficientnet_b7_ns").to(self.device)

        if self.pretrained_image_encoder:
            self._load_pretrained_weights()

        if self.freeze_image_encoder:
            for param in self.model.parameters():
                param.requires_grad = False

    def _load_pretrained_weights(self):
        """Restore ``self.model`` from the checkpoint at ``_PRETRAINED_CKPT``."""
        print("Loading pretrained image encoder...")
        # NOTE: torch.load unpickles the file; only load checkpoints you trust.
        checkpoint = torch.load(
            self._PRETRAINED_CKPT,
            map_location=torch.device(self.device),
        )
        # The checkpoint is either a bare state dict or a dict wrapping one
        # under "state_dict". Keep it in locals: assigning it to
        # ``self.state_dict`` would shadow nn.Module.state_dict() and break
        # every later call to it (e.g. when saving this module).
        weights = checkpoint.get("state_dict", checkpoint)
        # Drop a leading "module." from each key so the names match the bare model.
        cleaned = {re.sub(r"^module\.", "", key): value for key, value in weights.items()}
        self.model.load_state_dict(cleaned, strict=True)
        print("Loaded pretrained image encoder.")

    def forward(self, x):
        """Run the classifier on ``x`` and return the sigmoid of its output."""
        return self.sigmoid(self.model(x))
|
|
|
|
|
class RawNet(nn.Module):
    """Audio branch operating on raw waveforms.

    Pipeline: ``SincConv`` front end -> max-pool of the absolute value ->
    batch norm + SELU -> six residual blocks, each followed by a
    sigmoid-gated per-channel scaling -> batch norm + SELU -> GRU (last time
    step) -> two linear layers -> log-softmax.

    Args:
        args: Configuration object. The attributes read here are ``device``,
            ``in_channels``, ``gru_node``, ``nb_gru_layer``, ``nb_fc_node``,
            ``nb_classes``, ``pretrained_audio_encoder`` and
            ``freeze_audio_encoder``. The two flags are interpreted by
            truthiness.
    """

    # Checkpoint location, relative to the current working directory. Built
    # with os.path.join so the separator is correct on every platform (the
    # previous hard-coded backslash only resolved on Windows).
    _PRETRAINED_CKPT = os.path.join("pretrained", "RawNet.pth")

    def __init__(self, args):
        super(RawNet, self).__init__()

        self.device = args.device
        # Filter configuration: [sinc filters, block0/1 filts, block2+ filts, (unused here)].
        self.filts = [20, [20, 20], [20, 128], [128, 128]]

        self.Sinc_conv = SincConv(device=self.device,
                                  out_channels=self.filts[0],
                                  kernel_size=1024,
                                  in_channels=args.in_channels)

        self.first_bn = nn.BatchNorm1d(num_features=self.filts[0])
        self.selu = nn.SELU(inplace=True)

        self.block0 = nn.Sequential(Residual_block(nb_filts=self.filts[1], first=True))
        self.block1 = nn.Sequential(Residual_block(nb_filts=self.filts[1]))
        self.block2 = nn.Sequential(Residual_block(nb_filts=self.filts[2]))
        # After block2, filts[2] becomes [128, 128] for the remaining blocks
        # (presumably [in, out] channels -- confirm against Residual_block).
        self.filts[2][0] = self.filts[2][1]
        self.block3 = nn.Sequential(Residual_block(nb_filts=self.filts[2]))
        self.block4 = nn.Sequential(Residual_block(nb_filts=self.filts[2]))
        self.block5 = nn.Sequential(Residual_block(nb_filts=self.filts[2]))
        self.avgpool = nn.AdaptiveAvgPool1d(1)

        # One square linear "attention" layer per residual block.
        narrow = self.filts[1][-1]
        wide = self.filts[2][-1]
        self.fc_attention0 = self._make_attention_fc(in_features=narrow, l_out_features=narrow)
        self.fc_attention1 = self._make_attention_fc(in_features=narrow, l_out_features=narrow)
        self.fc_attention2 = self._make_attention_fc(in_features=wide, l_out_features=wide)
        self.fc_attention3 = self._make_attention_fc(in_features=wide, l_out_features=wide)
        self.fc_attention4 = self._make_attention_fc(in_features=wide, l_out_features=wide)
        self.fc_attention5 = self._make_attention_fc(in_features=wide, l_out_features=wide)

        self.bn_before_gru = nn.BatchNorm1d(num_features=wide)
        self.gru = nn.GRU(input_size=wide,
                          hidden_size=args.gru_node,
                          num_layers=args.nb_gru_layer,
                          batch_first=True)

        self.fc1_gru = nn.Linear(in_features=args.gru_node,
                                 out_features=args.nb_fc_node)
        self.fc2_gru = nn.Linear(in_features=args.nb_fc_node,
                                 out_features=args.nb_classes,
                                 bias=True)

        self.sig = nn.Sigmoid()
        self.logsoftmax = nn.LogSoftmax(dim=1)

        self.pretrained_audio_encoder = args.pretrained_audio_encoder
        self.freeze_audio_encoder = args.freeze_audio_encoder

        if self.pretrained_audio_encoder:
            print("Loading pretrained audio encoder")
            # NOTE: torch.load unpickles the file; only load checkpoints you trust.
            ckpt = torch.load(self._PRETRAINED_CKPT,
                              map_location=torch.device(self.device))
            self.load_state_dict(ckpt, strict=True)
            # Report success only once the weights are actually in place.
            print("Loaded pretrained audio encoder")

        if self.freeze_audio_encoder:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, x, y=None):
        """Compute log-probabilities for a batch of waveforms.

        Args:
            x: Tensor that is reshaped to ``(batch, 1, samples)`` using its
                first two dimensions, i.e. a ``(batch, samples)`` input.
            y: Unused; accepted for call-site compatibility.

        Returns:
            Log-softmax (over dim 1) of the final linear layer's output.
        """
        nb_samp = x.shape[0]
        len_seq = x.shape[1]
        x = x.view(nb_samp, 1, len_seq)

        x = self.Sinc_conv(x)
        x = F.max_pool1d(torch.abs(x), 3)
        x = self.first_bn(x)
        x = self.selu(x)

        stages = (
            (self.block0, self.fc_attention0),
            (self.block1, self.fc_attention1),
            (self.block2, self.fc_attention2),
            (self.block3, self.fc_attention3),
            (self.block4, self.fc_attention4),
            (self.block5, self.fc_attention5),
        )
        for block, fc_attention in stages:
            x = self._scaled_block(block, fc_attention, x)

        x = self.bn_before_gru(x)
        x = self.selu(x)
        # (batch, channels, time) -> (batch, time, channels) for the batch_first GRU.
        x = x.permute(0, 2, 1)
        self.gru.flatten_parameters()
        x, _ = self.gru(x)
        x = x[:, -1, :]  # keep only the last time step
        x = self.fc1_gru(x)
        x = self.fc2_gru(x)
        return self.logsoftmax(x)

    def _scaled_block(self, block, fc_attention, x):
        """Apply ``block`` to ``x``, then scale and shift by a sigmoid gate per channel."""
        feat = block(x)
        # Average over the last dimension to get one value per (sample, channel).
        gate = self.avgpool(feat).view(feat.size(0), -1)
        gate = fc_attention(gate)
        # Trailing singleton dimension lets the gate broadcast over the last axis.
        gate = self.sig(gate).view(gate.size(0), gate.size(1), -1)
        return feat * gate + gate

    def _make_attention_fc(self, in_features, l_out_features):
        """Return an ``nn.Sequential`` holding one ``nn.Linear(in_features, l_out_features)``."""
        return nn.Sequential(
            nn.Linear(in_features=in_features, out_features=l_out_features)
        )

    def _make_layer(self, nb_blocks, nb_filts, first=False):
        """Stack ``nb_blocks`` residual blocks into an ``nn.Sequential``.

        Only the first block receives the caller's ``first`` flag; later
        blocks get ``False``.

        NOTE: when at least one block is built, ``nb_filts`` is modified in
        place (``nb_filts[0] = nb_filts[1]``) after the first block, so the
        caller's list changes too.
        """
        layers = []
        for i in range(nb_blocks):
            first = first if i == 0 else False
            layers.append(Residual_block(nb_filts=nb_filts, first=first))
            if i == 0:
                nb_filts[0] = nb_filts[1]
        return nn.Sequential(*layers)