import re
import os
import wget
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from models.rawnet import SincConv, Residual_block
from models.classifiers import DeepFakeClassifier

# Checkpoint locations, relative to the working directory. Built with
# os.path.join so they resolve on POSIX as well as on Windows (a literal
# "pretrained\\..." string only names the intended file on Windows).
IMAGE_ENCODER_CKPT = os.path.join(
    "pretrained", "final_999_DeepFakeClassifier_tf_efficientnet_b7_ns_0_23"
)
AUDIO_ENCODER_CKPT = os.path.join("pretrained", "RawNet.pth")


class ImageEncoder(nn.Module):
    """Image branch: a ``DeepFakeClassifier`` (tf_efficientnet_b7_ns) followed by a sigmoid.

    Reads from ``args``:
        device: device the wrapped classifier is moved to and weights are mapped to.
        pretrained_image_encoder: if truthy, load weights from ``IMAGE_ENCODER_CKPT``.
        freeze_image_encoder: if truthy, set ``requires_grad = False`` on every
            parameter of the wrapped classifier.
    """

    def __init__(self, args):
        super().__init__()
        self.device = args.device
        self.args = args
        self.flatten = nn.Flatten()  # not used by forward(); kept for attribute compatibility
        self.sigmoid = nn.Sigmoid()
        self.pretrained_image_encoder = args.pretrained_image_encoder
        self.freeze_image_encoder = args.freeze_image_encoder

        self.model = DeepFakeClassifier(encoder="tf_efficientnet_b7_ns").to(self.device)
        if self.pretrained_image_encoder:
            self._load_pretrained()

        # Applied whether or not pretrained weights were loaded.
        if self.freeze_image_encoder:
            for param in self.model.parameters():
                param.requires_grad = False

    def _load_pretrained(self):
        """Load ``IMAGE_ENCODER_CKPT`` into ``self.model`` with strict key matching."""
        # The checkpoint is kept in locals on purpose: storing it as
        # ``self.state_dict`` shadowed nn.Module.state_dict(), so calling
        # state_dict() on this module (or on any parent containing it) failed,
        # and storing it on ``self`` kept a second copy of every weight alive.
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        ckpt = torch.load(IMAGE_ENCODER_CKPT, map_location=torch.device(self.device))
        # Accept either a bare state dict or one wrapped under "state_dict".
        state_dict = ckpt.get("state_dict", ckpt)
        print("Loading pretrained image encoder...")
        # Drop a leading "module." from every key (the dot is escaped so that
        # e.g. "moduleX..." is left untouched).
        self.model.load_state_dict(
            {re.sub(r"^module\.", "", k): v for k, v in state_dict.items()},
            strict=True,
        )
        print("Loaded pretrained image encoder.")

    def forward(self, x):
        """Return ``sigmoid(self.model(x))``.

        NOTE(review): presumably the classifier emits logits, making this a
        score in (0, 1); output shape is whatever DeepFakeClassifier returns.
        """
        return self.sigmoid(self.model(x))


class RawNet(nn.Module):
    """Raw-waveform audio classifier.

    Pipeline: SincConv -> |x| -> max-pool -> BatchNorm -> SELU -> six residual
    blocks, each followed by filter-wise sigmoid attention -> BatchNorm -> SELU
    -> GRU (last time step) -> two linear layers -> LogSoftmax.

    Reads from ``args``: ``device``, ``in_channels``, ``gru_node``,
    ``nb_gru_layer``, ``nb_fc_node``, ``nb_classes``,
    ``pretrained_audio_encoder`` (load ``AUDIO_ENCODER_CKPT``) and
    ``freeze_audio_encoder`` (disable gradients for all parameters).
    """

    def __init__(self, args):
        super().__init__()
        self.device = args.device
        self.filts = [20, [20, 20], [20, 128], [128, 128]]

        self.Sinc_conv = SincConv(
            device=self.device,
            out_channels=self.filts[0],
            kernel_size=1024,
            in_channels=args.in_channels,
        )
        self.first_bn = nn.BatchNorm1d(num_features=self.filts[0])
        self.selu = nn.SELU(inplace=True)

        self.block0 = nn.Sequential(Residual_block(nb_filts=self.filts[1], first=True))
        self.block1 = nn.Sequential(Residual_block(nb_filts=self.filts[1]))
        self.block2 = nn.Sequential(Residual_block(nb_filts=self.filts[2]))
        # block2 maps 20 -> 128; the remaining blocks map 128 -> 128.
        self.filts[2][0] = self.filts[2][1]
        self.block3 = nn.Sequential(Residual_block(nb_filts=self.filts[2]))
        self.block4 = nn.Sequential(Residual_block(nb_filts=self.filts[2]))
        self.block5 = nn.Sequential(Residual_block(nb_filts=self.filts[2]))

        self.avgpool = nn.AdaptiveAvgPool1d(1)

        # One attention FC per residual block, sized to that block's output filters.
        self.fc_attention0 = self._make_attention_fc(
            in_features=self.filts[1][-1], l_out_features=self.filts[1][-1])
        self.fc_attention1 = self._make_attention_fc(
            in_features=self.filts[1][-1], l_out_features=self.filts[1][-1])
        self.fc_attention2 = self._make_attention_fc(
            in_features=self.filts[2][-1], l_out_features=self.filts[2][-1])
        self.fc_attention3 = self._make_attention_fc(
            in_features=self.filts[2][-1], l_out_features=self.filts[2][-1])
        self.fc_attention4 = self._make_attention_fc(
            in_features=self.filts[2][-1], l_out_features=self.filts[2][-1])
        self.fc_attention5 = self._make_attention_fc(
            in_features=self.filts[2][-1], l_out_features=self.filts[2][-1])

        self.bn_before_gru = nn.BatchNorm1d(num_features=self.filts[2][-1])
        self.gru = nn.GRU(
            input_size=self.filts[2][-1],
            hidden_size=args.gru_node,
            num_layers=args.nb_gru_layer,
            batch_first=True,
        )
        self.fc1_gru = nn.Linear(in_features=args.gru_node, out_features=args.nb_fc_node)
        self.fc2_gru = nn.Linear(
            in_features=args.nb_fc_node, out_features=args.nb_classes, bias=True)
        self.sig = nn.Sigmoid()
        self.logsoftmax = nn.LogSoftmax(dim=1)

        self.pretrained_audio_encoder = args.pretrained_audio_encoder
        self.freeze_audio_encoder = args.freeze_audio_encoder
        if self.pretrained_audio_encoder:
            print("Loading pretrained audio encoder")
            # NOTE: torch.load unpickles the file; only load trusted checkpoints.
            ckpt = torch.load(AUDIO_ENCODER_CKPT, map_location=torch.device(self.device))
            self.load_state_dict(ckpt, strict=True)
            # Report success only after the weights have actually been loaded.
            print("Loaded pretrained audio encoder")
        # Applied whether or not pretrained weights were loaded.
        if self.freeze_audio_encoder:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, x, y=None):
        """Return log-probabilities over ``args.nb_classes`` for a batch of waveforms.

        Args:
            x: waveform batch; it is reshaped to ``(x.shape[0], 1, x.shape[1])``
                before the SincConv, so presumably ``(batch, num_samples)``.
            y: unused; kept so existing call sites keep working.

        Returns:
            Tensor ``(batch, nb_classes)`` produced by ``LogSoftmax(dim=1)``.
        """
        nb_samp = x.shape[0]
        len_seq = x.shape[1]
        x = x.view(nb_samp, 1, len_seq)

        x = self.Sinc_conv(x)
        x = F.max_pool1d(torch.abs(x), 3)
        x = self.first_bn(x)
        x = self.selu(x)

        stages = (
            (self.block0, self.fc_attention0),
            (self.block1, self.fc_attention1),
            (self.block2, self.fc_attention2),
            (self.block3, self.fc_attention3),
            (self.block4, self.fc_attention4),
            (self.block5, self.fc_attention5),
        )
        for block, fc_attention in stages:
            x = self._block_with_attention(x, block, fc_attention)

        x = self.bn_before_gru(x)
        x = self.selu(x)
        x = x.permute(0, 2, 1)  # (batch, filt, time) -> (batch, time, filt)
        self.gru.flatten_parameters()
        x, _ = self.gru(x)
        x = x[:, -1, :]  # GRU output at the last time step
        x = self.fc1_gru(x)
        x = self.fc2_gru(x)
        return self.logsoftmax(x)

    def _block_with_attention(self, x, block, fc_attention):
        """Run *block* on *x*, then rescale its output with filter-wise attention.

        A per-filter weight ``s = sigmoid(fc_attention(avgpool(out)))`` is
        applied as ``out * s + s`` (multiplicative and additive), broadcast
        over the time axis.
        """
        out = block(x)
        scale = self.avgpool(out).view(out.size(0), -1)  # (batch, filter)
        scale = fc_attention(scale)
        scale = self.sig(scale).view(scale.size(0), scale.size(1), -1)  # (batch, filter, 1)
        return out * scale + scale  # (batch, filter, time) x (batch, filter, 1)

    def _make_attention_fc(self, in_features, l_out_features):
        """Return the attention layer: a single Linear wrapped in nn.Sequential.

        The Sequential wrapper is kept because it determines the parameter
        names (``fc_attentionN.0.weight`` / ``.bias``) that checkpoints use.
        """
        l_fc = [nn.Linear(in_features=in_features, out_features=l_out_features)]
        return nn.Sequential(*l_fc)

    def _make_layer(self, nb_blocks, nb_filts, first=False):
        """Stack *nb_blocks* ``Residual_block``s into an nn.Sequential.

        The first block uses *nb_filts* as given and receives *first*. Every
        later block uses *nb_filts* with entry 0 replaced by entry 1 and
        ``first=False``. The caller's *nb_filts* is copied, not mutated.
        """
        filts = list(nb_filts)  # private copy; also accepts tuples
        layers = []
        for i in range(nb_blocks):
            layers.append(
                Residual_block(nb_filts=list(filts), first=first if i == 0 else False)
            )
            if i == 0:
                filts[0] = filts[1]
        return nn.Sequential(*layers)