"""
This code is modified version of MesoNet DeepFake detection solution
from FakeAVCeleb repository - https://github.com/DASH-Lab/FakeAVCeleb/blob/main/models/MesoNet.py.
"""
import torch
import torch.nn as nn
from src import frontends


class MesoInception4(nn.Module):
"""
    PyTorch implementation of MesoInception4
Author: Honggu Liu
Date: July 7, 2019
"""
def __init__(self, num_classes=1, **kwargs):
super().__init__()
self.fc1_dim = kwargs.get("fc1_dim", 1024)
input_channels = kwargs.get("input_channels", 3)
self.num_classes = num_classes
#InceptionLayer1
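        # Four parallel branches: a 1x1 conv, a 1x1->3x3 conv, and two
        # 1x1->dilated 3x3 convs (dilation 2 and 3); their outputs are
        # concatenated along the channel dimension (1 + 4 + 4 + 2 = 11).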
self.Incption1_conv1 = nn.Conv2d(input_channels, 1, 1, padding=0, bias=False)
self.Incption1_conv2_1 = nn.Conv2d(input_channels, 4, 1, padding=0, bias=False)
self.Incption1_conv2_2 = nn.Conv2d(4, 4, 3, padding=1, bias=False)
self.Incption1_conv3_1 = nn.Conv2d(input_channels, 4, 1, padding=0, bias=False)
self.Incption1_conv3_2 = nn.Conv2d(4, 4, 3, padding=2, dilation=2, bias=False)
self.Incption1_conv4_1 = nn.Conv2d(input_channels, 2, 1, padding=0, bias=False)
self.Incption1_conv4_2 = nn.Conv2d(2, 2, 3, padding=3, dilation=3, bias=False)
self.Incption1_bn = nn.BatchNorm2d(11)
#InceptionLayer2
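        # Same four-branch structure applied to the 11-channel output of the
        # first block; the concatenated output has 2 + 4 + 4 + 2 = 12 channels.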
self.Incption2_conv1 = nn.Conv2d(11, 2, 1, padding=0, bias=False)
self.Incption2_conv2_1 = nn.Conv2d(11, 4, 1, padding=0, bias=False)
self.Incption2_conv2_2 = nn.Conv2d(4, 4, 3, padding=1, bias=False)
self.Incption2_conv3_1 = nn.Conv2d(11, 4, 1, padding=0, bias=False)
self.Incption2_conv3_2 = nn.Conv2d(4, 4, 3, padding=2, dilation=2, bias=False)
self.Incption2_conv4_1 = nn.Conv2d(11, 2, 1, padding=0, bias=False)
self.Incption2_conv4_2 = nn.Conv2d(2, 2, 3, padding=3, dilation=3, bias=False)
self.Incption2_bn = nn.BatchNorm2d(12)
#Normal Layer
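        # Plain conv/pool trunk that shrinks the feature map before the
        # fully connected classifier head.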
self.conv1 = nn.Conv2d(12, 16, 5, padding=2, bias=False)
self.relu = nn.ReLU(inplace=True)
self.leakyrelu = nn.LeakyReLU(0.1)
self.bn1 = nn.BatchNorm2d(16)
self.maxpooling1 = nn.MaxPool2d(kernel_size=(2, 2))
self.conv2 = nn.Conv2d(16, 16, 5, padding=2, bias=False)
self.maxpooling2 = nn.MaxPool2d(kernel_size=(4, 4))
self.dropout = nn.Dropout2d(0.5)
self.fc1 = nn.Linear(self.fc1_dim, 16)
self.fc2 = nn.Linear(16, num_classes)

    # Inception layer 1: run the four branches and concatenate along channels
    def InceptionLayer1(self, input):
x1 = self.Incption1_conv1(input)
x2 = self.Incption1_conv2_1(input)
x2 = self.Incption1_conv2_2(x2)
x3 = self.Incption1_conv3_1(input)
x3 = self.Incption1_conv3_2(x3)
x4 = self.Incption1_conv4_1(input)
x4 = self.Incption1_conv4_2(x4)
y = torch.cat((x1, x2, x3, x4), 1)
y = self.Incption1_bn(y)
y = self.maxpooling1(y)
        return y

    def InceptionLayer2(self, input):
x1 = self.Incption2_conv1(input)
x2 = self.Incption2_conv2_1(input)
x2 = self.Incption2_conv2_2(x2)
x3 = self.Incption2_conv3_1(input)
x3 = self.Incption2_conv3_2(x3)
x4 = self.Incption2_conv4_1(input)
x4 = self.Incption2_conv4_2(x4)
y = torch.cat((x1, x2, x3, x4), 1)
y = self.Incption2_bn(y)
y = self.maxpooling1(y)
        return y

    def forward(self, input):
x = self._compute_embedding(input)
        return x

    def _compute_embedding(self, input):
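        # The shape comments below assume a 256x256 spatial input.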
x = self.InceptionLayer1(input) #(Batch, 11, 128, 128)
x = self.InceptionLayer2(x) #(Batch, 12, 64, 64)
x = self.conv1(x) #(Batch, 16, 64 ,64)
x = self.relu(x)
x = self.bn1(x)
x = self.maxpooling1(x) #(Batch, 16, 32, 32)
x = self.conv2(x) #(Batch, 16, 32, 32)
x = self.relu(x)
x = self.bn1(x)
x = self.maxpooling2(x) #(Batch, 16, 8, 8)
x = x.view(x.size(0), -1) #(Batch, 16*8*8)
x = self.dropout(x)
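        # Pool the flattened feature vector to a fixed length of fc1_dim so
        # that fc1 sees the same input size regardless of the original
        # spatial resolution.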
x = nn.AdaptiveAvgPool1d(self.fc1_dim)(x)
        x = self.fc1(x)  # (Batch, 16)
x = self.leakyrelu(x)
x = self.dropout(x)
x = self.fc2(x)
        return x


class FrontendMesoInception4(MesoInception4):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.device = kwargs['device']
frontend_name = kwargs.get("frontend_algorithm", [])
self.frontend = frontends.get_frontend(frontend_name)
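        # The frontend (e.g. LFCC) from src.frontends is assumed to convert
        # the raw input into 2-D features before the shared CNN backbone.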
        print(f"Using {frontend_name} frontend")

    def forward(self, x):
x = self.frontend(x)
x = self._compute_embedding(x)
        return x


if __name__ == "__main__":
model = FrontendMesoInception4(
input_channels=2,
fc1_dim=1024,
device='cuda',
frontend_algorithm="lfcc"
    )

    def count_parameters(model) -> int:
pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
return pytorch_total_params
print(count_parameters(model))
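
    # Minimal smoke test of the base MesoInception4 on a dummy image-like
    # input; the 256x256 size and batch of 2 are illustrative assumptions,
    # chosen so the flattened features match the default fc1_dim of 1024.
    base_model = MesoInception4(input_channels=3, fc1_dim=1024)
    base_model.eval()
    with torch.no_grad():
        dummy_input = torch.randn(2, 3, 256, 256)
        print(base_model(dummy_input).shape)  # expected: torch.Size([2, 1])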