# deepfake-detect/utils/architectures.py
import torch
import torch.nn as nn
import torch.nn.functional as F  # kept for the commented-out softmax calls below
from torchvision import transforms
from torchvision.models import resnet
from efficientnet_pytorch import EfficientNet
from efficientnet_pytorch.utils import round_filters, get_same_padding_conv2d
from efficientnet_pytorch.model import MBConvBlock
from pytorchcv.model_provider import get_model

class Head(nn.Module):
    """Classification head: Flatten -> BN -> Dropout -> Linear(512) -> ReLU -> BN -> Dropout -> Linear."""
def __init__(self, in_f, out_f):
super(Head, self).__init__()
self.f = nn.Flatten()
self.l = nn.Linear(in_f, 512)
self.d = nn.Dropout(0.5)
self.o = nn.Linear(512, out_f)
self.b1 = nn.BatchNorm1d(in_f)
self.b2 = nn.BatchNorm1d(512)
self.r = nn.ReLU()
def forward(self, x):
x = self.f(x)
x = self.b1(x)
x = self.d(x)
x = self.l(x)
x = self.r(x)
x = self.b2(x)
x = self.d(x)
out = self.o(x)
return out

class FCN(nn.Module):
    """A feature-extracting backbone followed by a Head classifier."""
def __init__(self, base, in_f, out_f):
super(FCN, self).__init__()
self.base = base
self.h1 = Head(in_f, out_f)
def forward(self, x):
x = self.base(x)
return self.h1(x)
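

# Hedged usage sketch (not part of the original file): any backbone whose output
# flattens to `in_f` features can be wrapped with the two-layer Head classifier.
# The toy backbone below is illustrative only.
def _example_fcn():
    backbone = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.AdaptiveAvgPool2d((4, 4)))
    model = FCN(backbone, in_f=8 * 4 * 4, out_f=2)
    return model(torch.randn(4, 3, 32, 32))  # logits of shape (4, 2)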

class BaseFCN(nn.Module):
    """Shallow FCN over flattened 625-dim (25x25) inputs: Linear(625, 256) -> Dropout -> Linear."""
def __init__(self, n_classes: int):
super(BaseFCN, self).__init__()
self.f = nn.Flatten()
self.l = nn.Linear(625, 256)
self.d = nn.Dropout(0.5)
self.o = nn.Linear(256, n_classes)
def forward(self, x):
x = self.f(x)
x = self.l(x)
x = self.d(x)
out = self.o(x)
return out
def get_trainable_parameters_cooccur(self):
return self.parameters()

class BaseFCNHigh(nn.Module):
    """BaseFCN with a wider 512-unit hidden layer."""
def __init__(self, n_classes: int):
super(BaseFCNHigh, self).__init__()
self.f = nn.Flatten()
self.l = nn.Linear(625, 512)
self.d = nn.Dropout(0.5)
self.o = nn.Linear(512, n_classes)
def forward(self, x):
x = self.f(x)
x = self.l(x)
x = self.d(x)
out = self.o(x)
return out
def get_trainable_parameters_cooccur(self):
return self.parameters()

class BaseFCN4(nn.Module):
    """Deeper FCN: 625 -> 512 -> 384 -> 256 -> n_classes, with dropout after each hidden layer."""
def __init__(self, n_classes: int):
super(BaseFCN4, self).__init__()
self.f = nn.Flatten()
self.l1 = nn.Linear(625, 512)
self.l2 = nn.Linear(512, 384)
self.l3 = nn.Linear(384, 256)
self.d = nn.Dropout(0.5)
self.o = nn.Linear(256, n_classes)
def forward(self, x):
x = self.f(x)
x = self.l1(x)
x = self.d(x)
x = self.l2(x)
x = self.d(x)
x = self.l3(x)
x = self.d(x)
out = self.o(x)
return out
def get_trainable_parameters_cooccur(self):
return self.parameters()

class BaseFCNBnR(nn.Module):
    """BaseFCN variant with batch normalization and ReLU around the hidden layer."""
def __init__(self, n_classes: int):
super(BaseFCNBnR, self).__init__()
self.f = nn.Flatten()
self.b1 = nn.BatchNorm1d(625)
self.b2 = nn.BatchNorm1d(256)
self.l = nn.Linear(625, 256)
self.d = nn.Dropout(0.5)
self.o = nn.Linear(256, n_classes)
self.r = nn.ReLU()
def forward(self, x):
x = self.f(x)
x = self.b1(x)
x = self.d(x)
x = self.l(x)
x = self.r(x)
x = self.b2(x)
x = self.d(x)
out = self.o(x)
return out
def get_trainable_parameters_cooccur(self):
return self.parameters()
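

# Hedged sketch (illustrative): every BaseFCN* variant flattens its input to 625
# features, i.e. a 25x25 map per sample (the `_cooccur` method names suggest
# co-occurrence matrices, but that is an assumption).
def _example_base_fcn():
    model = BaseFCNBnR(n_classes=2)
    model.eval()  # eval mode so BatchNorm1d also accepts a single sample
    return model(torch.randn(1, 1, 25, 25))  # logits of shape (1, 2)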

def forward_resnet_conv(net, x, upto: int = 4):
    """
    Forward a ResNet through its convolutional part only.
    :param net: torchvision ResNet model
    :param x: input batch
    :param upto: index of the last residual stage to run (1 to 4)
    :return: feature maps after the selected stage
    """
x = net.conv1(x) # N / 2
x = net.bn1(x)
x = net.relu(x)
x = net.maxpool(x) # N / 4
if upto >= 1:
x = net.layer1(x) # N / 4
if upto >= 2:
x = net.layer2(x) # N / 8
if upto >= 3:
x = net.layer3(x) # N / 16
if upto >= 4:
x = net.layer4(x) # N / 32
return x
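

# Hedged sketch: with `upto=3`, a 224x224 input stops after layer3, giving a
# 1024-channel map at 1/16 resolution (shapes per torchvision's ResNet-50).
def _example_forward_resnet_conv():
    net = resnet.resnet50(pretrained=False)
    return forward_resnet_conv(net, torch.randn(1, 3, 224, 224), upto=3)  # (1, 1024, 14, 14)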

class FeatureExtractor(nn.Module):
    """
    Abstract class to be extended when supporting feature extraction.
    It also provides the standard input normalizer and the trainable parameters.
    """
def features(self, x: torch.Tensor) -> torch.Tensor:
raise NotImplementedError
def get_trainable_parameters(self):
return self.parameters()
@staticmethod
def get_normalizer():
return transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

class FeatureExtractorGray(nn.Module):
    """
    Abstract class to be extended when supporting feature extraction on grayscale inputs.
    It also provides the standard single-channel input normalizer and the trainable parameters.
    """
def features(self, x: torch.Tensor) -> torch.Tensor:
raise NotImplementedError
def get_trainable_parameters(self):
return self.parameters()
@staticmethod
def get_normalizer():
return transforms.Normalize(mean=[0.479], std=[0.226])
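

# Hedged sketch: the normalizers above are meant to follow ToTensor() in the
# preprocessing pipeline (ImageNet statistics for RGB, one channel for grayscale).
_example_rgb_preprocess = transforms.Compose([
    transforms.ToTensor(),
    FeatureExtractor.get_normalizer(),
])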

class EfficientNetGen(FeatureExtractor):
    """Generic EfficientNet classifier: pooled backbone features followed by a fresh linear layer (the original _fc is removed)."""
def __init__(self, model: str, n_classes: int, pretrained: bool):
super(EfficientNetGen, self).__init__()
if pretrained:
self.efficientnet = EfficientNet.from_pretrained(model)
else:
self.efficientnet = EfficientNet.from_name(model)
self.classifier = nn.Linear(self.efficientnet._conv_head.out_channels, n_classes)
del self.efficientnet._fc
def features(self, x: torch.Tensor) -> torch.Tensor:
x = self.efficientnet.extract_features(x)
x = self.efficientnet._avg_pooling(x)
x = x.flatten(start_dim=1)
return x
def forward(self, x):
x = self.features(x)
x = self.efficientnet._dropout(x)
x = self.classifier(x)
# x = F.softmax(x, dim=-1)
return x

class EfficientNetB0(EfficientNetGen):
    def __init__(self, n_classes: int, pretrained: bool):
        super(EfficientNetB0, self).__init__(model='efficientnet-b0', n_classes=n_classes, pretrained=pretrained)


class EfficientNetB4(EfficientNetGen):
    def __init__(self, n_classes: int, pretrained: bool):
        super(EfficientNetB4, self).__init__(model='efficientnet-b4', n_classes=n_classes, pretrained=pretrained)
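

# Hedged usage sketch: an EfficientNet-B0 binary classifier; the pooled feature
# vector is 1280-dim for B0 (per the EfficientNet spec).
def _example_efficientnet_b0():
    model = EfficientNetB0(n_classes=2, pretrained=False)
    x = torch.randn(2, 3, 224, 224)
    feats = model.features(x)  # (2, 1280)
    logits = model(x)          # (2, 2)
    return feats, logits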

class EfficientNetGenPostStem(FeatureExtractor):
    """
    EfficientNet with a stride-1 stem followed by n_ir_blocks extra inverted-residual (MBConv)
    blocks; the last extra block applies the stride-2 downsampling the stem gave up.
    """
def __init__(self, model: str, n_classes: int, pretrained: bool, n_ir_blocks: int):
super(EfficientNetGenPostStem, self).__init__()
if pretrained:
self.efficientnet = EfficientNet.from_pretrained(model)
else:
self.efficientnet = EfficientNet.from_name(model)
self.n_ir_blocks = n_ir_blocks
self.classifier = nn.Linear(self.efficientnet._conv_head.out_channels, n_classes)
# modify STEM
in_channels = 3 # rgb
out_channels = round_filters(32, self.efficientnet._global_params)
Conv2d = get_same_padding_conv2d(image_size=self.efficientnet._global_params.image_size)
self.efficientnet._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=1, bias=False)
self.init_blocks_args = self.efficientnet._blocks_args[0]
self.init_blocks_args = self.init_blocks_args._replace(output_filters=32)
self.init_block = MBConvBlock(self.init_blocks_args, self.efficientnet._global_params)
self.last_block_args = self.efficientnet._blocks_args[0]
self.last_block_args = self.last_block_args._replace(output_filters=32, stride=2)
self.last_block = MBConvBlock(self.last_block_args, self.efficientnet._global_params)
del self.efficientnet._fc
def features(self, x: torch.Tensor) -> torch.Tensor:
x = self.efficientnet._swish(self.efficientnet._bn0(self.efficientnet._conv_stem(x)))
# init blocks
for b in range(self.n_ir_blocks - 1):
x = self.init_block(x, drop_connect_rate=0)
# last block
x = self.last_block(x, drop_connect_rate=0)
# standard blocks efficientNet:
for idx, block in enumerate(self.efficientnet._blocks):
drop_connect_rate = self.efficientnet._global_params.drop_connect_rate
if drop_connect_rate:
drop_connect_rate *= float(idx) / len(self.efficientnet._blocks)
x = block(x, drop_connect_rate=drop_connect_rate)
x = self.efficientnet._swish(self.efficientnet._bn1(self.efficientnet._conv_head(x)))
x = self.efficientnet._avg_pooling(x)
x = x.flatten(start_dim=1)
return x
def forward(self, x):
x = self.features(x)
x = self.efficientnet._dropout(x)
x = self.classifier(x)
# x = F.softmax(x, dim=-1)
return x

class EfficientNetB0PostStemIR(EfficientNetGenPostStem):
    def __init__(self, n_classes: int, pretrained: bool, n_ir_blocks: int):
        super(EfficientNetB0PostStemIR, self).__init__(model='efficientnet-b0', n_classes=n_classes,
                                                       pretrained=pretrained, n_ir_blocks=n_ir_blocks)

class EfficientNetGenPreStem(FeatureExtractor):
    """
    EfficientNet preceded by n_ir_blocks inverted-residual (MBConv) blocks that run directly
    on the RGB input; the stem is modified to take their 32-channel output.
    """
def __init__(self, model: str, n_classes: int, pretrained: bool, n_ir_blocks: int):
super(EfficientNetGenPreStem, self).__init__()
if pretrained:
self.efficientnet = EfficientNet.from_pretrained(model)
else:
self.efficientnet = EfficientNet.from_name(model)
self.n_ir_blocks = n_ir_blocks
self.classifier = nn.Linear(self.efficientnet._conv_head.out_channels, n_classes)
self.init_block_args = self.efficientnet._blocks_args[0]
self.init_block_args = self.init_block_args._replace(input_filters=3, output_filters=32)
self.init_block = MBConvBlock(self.init_block_args, self.efficientnet._global_params)
self.last_blocks_args = self.efficientnet._blocks_args[0]
self.last_blocks_args = self.last_blocks_args._replace(output_filters=32)
self.last_block = MBConvBlock(self.last_blocks_args, self.efficientnet._global_params)
# modify STEM
in_channels = 32
out_channels = round_filters(32, self.efficientnet._global_params)
Conv2d = get_same_padding_conv2d(image_size=self.efficientnet._global_params.image_size)
self.efficientnet._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False)
del self.efficientnet._fc
def features(self, x: torch.Tensor) -> torch.Tensor:
# init block
x = self.init_block(x, drop_connect_rate=0)
# other blocks
for b in range(self.n_ir_blocks - 1):
x = self.last_block(x, drop_connect_rate=0)
# standard stem efficientNet:
x = self.efficientnet._swish(self.efficientnet._bn0(self.efficientnet._conv_stem(x)))
# standard blocks efficientNet:
for idx, block in enumerate(self.efficientnet._blocks):
drop_connect_rate = self.efficientnet._global_params.drop_connect_rate
if drop_connect_rate:
drop_connect_rate *= float(idx) / len(self.efficientnet._blocks)
x = block(x, drop_connect_rate=drop_connect_rate)
x = self.efficientnet._swish(self.efficientnet._bn1(self.efficientnet._conv_head(x)))
x = self.efficientnet._avg_pooling(x)
x = x.flatten(start_dim=1)
return x
def forward(self, x):
x = self.features(x)
x = self.efficientnet._dropout(x)
x = self.classifier(x)
# x = F.softmax(x, dim=-1)
return x

class EfficientNetB0PreStemIR(EfficientNetGenPreStem):
    def __init__(self, n_classes: int, pretrained: bool, n_ir_blocks: int):
        super(EfficientNetB0PreStemIR, self).__init__(model='efficientnet-b0', n_classes=n_classes,
                                                      pretrained=pretrained, n_ir_blocks=n_ir_blocks)
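

# Hedged sketch contrasting the two IR-block variants as built above: PostStem
# keeps the RGB stem but makes it stride-1 and defers downsampling to an extra
# MBConv block; PreStem runs extra MBConv blocks on RGB before a 32-channel stem.
def _example_ir_variants():
    x = torch.randn(2, 3, 224, 224)
    post = EfficientNetB0PostStemIR(n_classes=2, pretrained=False, n_ir_blocks=2)
    pre = EfficientNetB0PreStemIR(n_classes=2, pretrained=False, n_ir_blocks=2)
    return post(x), pre(x)  # each: logits of shape (2, 2)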

class ResNet50(FeatureExtractor):
    """ResNet-50 feature extractor with a fresh fully connected classifier (the original fc is removed)."""
def __init__(self, n_classes: int, pretrained: bool):
super(ResNet50, self).__init__()
self.resnet = resnet.resnet50(pretrained=pretrained)
self.fc = nn.Linear(in_features=self.resnet.fc.in_features, out_features=n_classes)
del self.resnet.fc
def features(self, x):
x = forward_resnet_conv(self.resnet, x)
x = self.resnet.avgpool(x).flatten(start_dim=1)
return x
def forward(self, x):
x = self.features(x)
x = self.fc(x)
return x
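

# Hedged usage sketch: ResNet50 exposes 2048-dim pooled features; `forward`
# simply adds the fresh linear classifier on top.
def _example_resnet50():
    model = ResNet50(n_classes=2, pretrained=False)
    feats = model.features(torch.randn(2, 3, 224, 224))  # (2, 2048)
    return model.fc(feats)  # logits of shape (2, 2)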
"""
Xception from Kaggle
"""
class XceptionWeiHao(FeatureExtractor):
def __init__(self, n_classes: int, pretrained: bool):
super(XceptionWeiHao, self).__init__()
self.model = get_model("xception", pretrained=pretrained)
self.model = nn.Sequential(*list(self.model.children())[:-1]) # Remove original output layer
self.model[0].final_block.pool = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)))
self.model = FCN(self.model, 2048, n_classes)
def features(self, x: torch.Tensor) -> torch.Tensor:
return self.model.base(x)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.features(x)
return self.model.h1(x)
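

# Hedged usage sketch (assumes pytorchcv's 299x299 Xception input size): the
# truncated base yields (N, 2048, 1, 1) maps, which Head flattens into logits.
def _example_xception():
    model = XceptionWeiHao(n_classes=2, pretrained=False)
    return model(torch.randn(2, 3, 299, 299))  # logits of shape (2, 2)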