# MantraNet/mantranet.py
import os
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from collections import OrderedDict
# Pytorch
import torch
from torch import nn
import torch.nn.functional as F
# pytorch-lightning
import pytorch_lightning as pl
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
## Reproduction of the hard sigmoid used by TensorFlow/Keras
## (slightly different from PyTorch's built-in nn.Hardsigmoid)
def hardsigmoid(T):
    # TF/Keras hard sigmoid: y = clip(0.2 * x + 0.5, 0, 1)
    return torch.clamp(0.2 * T + 0.5, min=0.0, max=1.0)
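# Quick sanity check (illustrative): inputs saturate at 0 below -2.5 and at 1
# above 2.5, and follow 0.2 * x + 0.5 in between.
# >>> hardsigmoid(torch.tensor([-3.0, 0.0, 3.0]))
# tensor([0.0000, 0.5000, 1.0000])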
## ConvLSTM: PyTorch equivalent of Keras's ConvLSTM2D layer
## Source: https://github.com/ndrplz/ConvLSTM_pytorch
class ConvLSTMCell(nn.Module):
def __init__(self, input_dim, hidden_dim, kernel_size, bias):
"""
Initialize ConvLSTM cell.
Parameters
----------
input_dim: int
Number of channels of input tensor.
hidden_dim: int
Number of channels of hidden state.
kernel_size: (int, int)
Size of the convolutional kernel.
bias: bool
Whether or not to add the bias.
"""
super(ConvLSTMCell, self).__init__()
self.input_dim = input_dim
self.hidden_dim = hidden_dim
self.kernel_size = kernel_size
self.padding = kernel_size[0] // 2, kernel_size[1] // 2
self.bias = bias
self.conv = nn.Conv2d(
in_channels=self.input_dim + self.hidden_dim,
out_channels=4 * self.hidden_dim,
kernel_size=self.kernel_size,
padding=self.padding,
bias=self.bias,
)
self.sigmoid = hardsigmoid
    def forward(self, input_tensor, cur_state):
        h_cur, c_cur = cur_state
        # Concatenate input and hidden state along the channel axis, then
        # compute all four gate pre-activations with a single convolution.
        combined = torch.cat([input_tensor, h_cur], dim=1)
        combined_conv = self.conv(combined)
        cc_i, cc_f, cc_c, cc_o = torch.split(combined_conv, self.hidden_dim, dim=1)
        i = self.sigmoid(cc_i)  # input gate
        f = self.sigmoid(cc_f)  # forget gate
        c_next = f * c_cur + i * torch.tanh(cc_c)  # updated cell memory
        o = self.sigmoid(cc_o)  # output gate
        h_next = o * torch.tanh(c_next)  # updated hidden state
        return h_next, c_next
def init_hidden(self, batch_size, image_size):
height, width = image_size
return (
torch.zeros(
batch_size,
self.hidden_dim,
height,
width,
device=self.conv.weight.device,
),
torch.zeros(
batch_size,
self.hidden_dim,
height,
width,
device=self.conv.weight.device,
),
)
class ConvLSTM(nn.Module):
"""
Parameters:
input_dim: Number of channels in input
hidden_dim: Number of hidden channels
kernel_size: Size of kernel in convolutions
num_layers: Number of LSTM layers stacked on each other
batch_first: Whether or not dimension 0 is the batch or not
bias: Bias or no bias in Convolution
return_all_layers: Return the list of computations for all layers
Note: Will do same padding.
Input:
A tensor of size B, T, C, H, W or T, B, C, H, W
    Output:
        A tuple of two lists, each of length num_layers (or 1 if return_all_layers is False):
            0 - layer_output_list: per-layer outputs, each a tensor of shape (b, t, hidden_dim, h, w)
            1 - last_state_list: per-layer last states, each a tuple (h, c) of
                hidden state and cell memory
    Example:
        >>> x = torch.rand((32, 10, 64, 128, 128))
        >>> convlstm = ConvLSTM(64, 16, (3, 3), 1, True, True, False)
        >>> _, last_states = convlstm(x)
        >>> h = last_states[0][0]  # 0 for layer index, 0 for h index
"""
def __init__(
self,
input_dim,
hidden_dim,
kernel_size,
num_layers,
batch_first=False,
bias=True,
return_all_layers=False,
):
super(ConvLSTM, self).__init__()
self._check_kernel_size_consistency(kernel_size)
# Make sure that both `kernel_size` and `hidden_dim` are lists having len == num_layers
kernel_size = self._extend_for_multilayer(kernel_size, num_layers)
hidden_dim = self._extend_for_multilayer(hidden_dim, num_layers)
if not len(kernel_size) == len(hidden_dim) == num_layers:
raise ValueError("Inconsistent list length.")
self.input_dim = input_dim
self.hidden_dim = hidden_dim
self.kernel_size = kernel_size
self.num_layers = num_layers
self.batch_first = batch_first
self.bias = bias
self.return_all_layers = return_all_layers
cell_list = []
for i in range(0, self.num_layers):
cur_input_dim = self.input_dim if i == 0 else self.hidden_dim[i - 1]
cell_list.append(
ConvLSTMCell(
input_dim=cur_input_dim,
hidden_dim=self.hidden_dim[i],
kernel_size=self.kernel_size[i],
bias=self.bias,
)
)
self.cell_list = nn.ModuleList(cell_list)
def forward(self, input_tensor, hidden_state=None):
"""
Parameters
----------
input_tensor: todo
5-D Tensor either of shape (t, b, c, h, w) or (b, t, c, h, w)
hidden_state: todo
None. todo implement stateful
Returns
-------
last_state_list, layer_output
"""
if not self.batch_first:
# (t, b, c, h, w) -> (b, t, c, h, w)
input_tensor = input_tensor.transpose(0, 1)
b, _, _, h, w = input_tensor.size()
        # Stateful ConvLSTM is not implemented yet
        if hidden_state is not None:
            raise NotImplementedError()
        else:
            # Initialization happens here, in forward, because the spatial
            # size of the hidden state is only known once an input arrives.
            hidden_state = self._init_hidden(batch_size=b, image_size=(h, w))
layer_output_list = []
last_state_list = []
seq_len = input_tensor.size(1)
cur_layer_input = input_tensor
for layer_idx in range(self.num_layers):
h, c = hidden_state[layer_idx]
output_inner = []
for t in range(seq_len):
h, c = self.cell_list[layer_idx](
input_tensor=cur_layer_input[:, t, :, :, :], cur_state=[h, c]
)
output_inner.append(h)
layer_output = torch.stack(output_inner, dim=1)
cur_layer_input = layer_output
layer_output_list.append(layer_output)
last_state_list.append([h, c])
if not self.return_all_layers:
layer_output_list = layer_output_list[-1:]
last_state_list = last_state_list[-1:]
return layer_output_list, last_state_list
def _init_hidden(self, batch_size, image_size):
init_states = []
for i in range(self.num_layers):
init_states.append(self.cell_list[i].init_hidden(batch_size, image_size))
return init_states
@staticmethod
def _check_kernel_size_consistency(kernel_size):
if not (
isinstance(kernel_size, tuple)
or (
isinstance(kernel_size, list)
and all([isinstance(elem, tuple) for elem in kernel_size])
)
):
raise ValueError("`kernel_size` must be tuple or list of tuples")
@staticmethod
def _extend_for_multilayer(param, num_layers):
if not isinstance(param, list):
param = [param] * num_layers
return param
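# Shape sanity check for ConvLSTM (illustrative, mirrors the docstring example):
# >>> x = torch.rand((2, 4, 64, 16, 16))            # (b, t, c, h, w)
# >>> convlstm = ConvLSTM(64, 8, (3, 3), num_layers=1, batch_first=True)
# >>> layer_outputs, last_states = convlstm(x)
# >>> layer_outputs[0].shape                        # (b, t, hidden_dim, h, w)
# torch.Size([2, 4, 8, 16, 16])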
class ConvGruCell(nn.Module):
def __init__(self, input_dim, hidden_dim, kernel_size, bias):
"""
Initialize ConvGRU cell.
Parameters
----------
input_dim: int
Number of channels of input tensor.
hidden_dim: int
Number of channels of hidden state.
kernel_size: (int, int)
Size of the convolutional kernel.
bias: bool
Whether or not to add the bias.
"""
super(ConvGruCell, self).__init__()
self.input_dim = input_dim
self.hidden_dim = hidden_dim
self.kernel_size = kernel_size
self.padding = kernel_size[0] // 2, kernel_size[1] // 2
self.bias = bias
self.sigmoid = hardsigmoid
self.conv1 = nn.Conv2d(
in_channels=self.input_dim + self.hidden_dim,
out_channels=2 * self.hidden_dim,
kernel_size=self.kernel_size,
padding=self.padding,
bias=self.bias,
)
self.conv2 = nn.Conv2d(
in_channels=self.input_dim + self.hidden_dim,
out_channels=self.hidden_dim,
kernel_size=self.kernel_size,
padding=self.padding,
bias=self.bias,
)
    def forward(self, input_tensor, cur_state):
        h_cur = cur_state
        # Concatenate hidden state and input along the channel axis, then
        # compute the reset and update gate pre-activations in one convolution.
        h_x = torch.cat([h_cur, input_tensor], dim=1)
        combined_conv = self.conv1(h_x)
        cc_r, cc_u = torch.split(combined_conv, self.hidden_dim, dim=1)
        r = self.sigmoid(cc_r)  # reset gate
        u = self.sigmoid(cc_u)  # update gate
        x_r_o_h = torch.cat([input_tensor, r * h_cur], dim=1)
        c = torch.tanh(self.conv2(x_r_o_h))  # candidate hidden state
        h_next = (1 - u) * h_cur + u * c
        return h_next
def init_hidden(self, batch_size, image_size):
height, width = image_size
return torch.zeros(
batch_size, self.hidden_dim, height, width, device=self.conv1.weight.device
)
class ConvGRU(nn.Module):
    """
    Parameters:
        input_dim: Number of channels in input
        hidden_dim: Number of hidden channels
        kernel_size: Size of kernel in convolutions
        num_layers: Number of GRU layers stacked on each other
        batch_first: Whether or not dimension 0 is the batch or not
        bias: Bias or no bias in Convolution
        return_all_layers: Return the list of computations for all layers
    Note: Will do same padding.
    Input:
        A tensor of size B, T, C, H, W or T, B, C, H, W
    Output:
        A tuple of two lists, each of length num_layers (or 1 if return_all_layers is False):
            0 - layer_output_list: per-layer outputs, each a tensor of shape (b, t, hidden_dim, h, w)
            1 - last_state_list: per-layer last hidden states h (a GRU has no
                separate cell memory)
    Example:
        >>> x = torch.rand((32, 10, 64, 128, 128))
        >>> convgru = ConvGRU(64, 16, (3, 3), 1, True, True, False)
        >>> _, last_states = convgru(x)
        >>> h = last_states[0]  # 0 for layer index
    """
def __init__(
self,
input_dim,
hidden_dim,
kernel_size,
num_layers,
batch_first=False,
bias=True,
return_all_layers=False,
):
super(ConvGRU, self).__init__()
self._check_kernel_size_consistency(kernel_size)
# Make sure that both `kernel_size` and `hidden_dim` are lists having len == num_layers
kernel_size = self._extend_for_multilayer(kernel_size, num_layers)
hidden_dim = self._extend_for_multilayer(hidden_dim, num_layers)
if not len(kernel_size) == len(hidden_dim) == num_layers:
raise ValueError("Inconsistent list length.")
self.input_dim = input_dim
self.hidden_dim = hidden_dim
self.kernel_size = kernel_size
self.num_layers = num_layers
self.batch_first = batch_first
self.bias = bias
self.return_all_layers = return_all_layers
cell_list = []
for i in range(0, self.num_layers):
cur_input_dim = self.input_dim if i == 0 else self.hidden_dim[i - 1]
cell_list.append(
ConvGruCell(
input_dim=cur_input_dim,
hidden_dim=self.hidden_dim[i],
kernel_size=self.kernel_size[i],
bias=self.bias,
)
)
self.cell_list = nn.ModuleList(cell_list)
def forward(self, input_tensor, hidden_state=None):
"""
Parameters
----------
input_tensor: todo
5-D Tensor either of shape (t, b, c, h, w) or (b, t, c, h, w)
hidden_state: todo
None. todo implement stateful
Returns
-------
last_state_list, layer_output
"""
if not self.batch_first:
# (t, b, c, h, w) -> (b, t, c, h, w)
input_tensor = input_tensor.transpose(0, 1)
b, _, _, h, w = input_tensor.size()
        # Stateful ConvGRU is not implemented yet
        if hidden_state is not None:
            raise NotImplementedError()
        else:
            # Initialization happens here, in forward, because the spatial
            # size of the hidden state is only known once an input arrives.
            hidden_state = self._init_hidden(batch_size=b, image_size=(h, w))
layer_output_list = []
last_state_list = []
seq_len = input_tensor.size(1)
cur_layer_input = input_tensor
for layer_idx in range(self.num_layers):
h = hidden_state[layer_idx]
output_inner = []
for t in range(seq_len):
h = self.cell_list[layer_idx](
input_tensor=cur_layer_input[:, t, :, :, :], cur_state=h
)
output_inner.append(h)
layer_output = torch.stack(output_inner, dim=1)
cur_layer_input = layer_output
layer_output_list.append(layer_output)
last_state_list.append(h)
if not self.return_all_layers:
layer_output_list = layer_output_list[-1:]
last_state_list = last_state_list[-1:]
return layer_output_list, last_state_list
def _init_hidden(self, batch_size, image_size):
init_states = []
for i in range(self.num_layers):
init_states.append(self.cell_list[i].init_hidden(batch_size, image_size))
return init_states
@staticmethod
def _check_kernel_size_consistency(kernel_size):
if not (
isinstance(kernel_size, tuple)
or (
isinstance(kernel_size, list)
and all([isinstance(elem, tuple) for elem in kernel_size])
)
):
raise ValueError("`kernel_size` must be tuple or list of tuples")
@staticmethod
def _extend_for_multilayer(param, num_layers):
if not isinstance(param, list):
param = [param] * num_layers
return param
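# Shape sanity check for ConvGRU (illustrative): the GRU variant returns bare
# hidden states instead of (h, c) tuples.
# >>> x = torch.rand((2, 4, 64, 16, 16))
# >>> convgru = ConvGRU(64, 8, (3, 3), num_layers=1, batch_first=True)
# >>> _, last_states = convgru(x)
# >>> last_states[0].shape
# torch.Size([2, 8, 16, 16])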
## Symmetric padding (no exact native equivalent in PyTorch)
## Source: https://discuss.pytorch.org/t/symmetric-padding/19866/3
def reflect(x, minx, maxx):
"""Reflects an array around two points making a triangular waveform that ramps up
and down, allowing for pad lengths greater than the input length"""
rng = maxx - minx
double_rng = 2 * rng
mod = np.fmod(x - minx, double_rng)
normed_mod = np.where(mod < 0, mod + double_rng, mod)
out = np.where(normed_mod >= rng, double_rng - normed_mod, normed_mod) + minx
return np.array(out, dtype=x.dtype)
def symm_pad(im, padding):
h, w = im.shape[-2:]
left, right, top, bottom = padding
x_idx = np.arange(-left, w + right)
y_idx = np.arange(-top, h + bottom)
x_pad = reflect(x_idx, -0.5, w - 0.5)
y_pad = reflect(y_idx, -0.5, h - 0.5)
xx, yy = np.meshgrid(x_pad, y_pad)
return im[..., yy, xx]
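# Example (illustrative): symmetric padding mirrors the border pixels,
# including the edge pixel itself, so [1, 2, 3] padded by one on the left and
# right becomes [1, 1, 2, 3, 3].
# >>> symm_pad(torch.tensor([[[[1., 2., 3.]]]]), (1, 1, 0, 0))
# tensor([[[[1., 1., 2., 3., 3.]]]])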
# Batch normalization equivalent to the TensorFlow reference (normalization
# only; no learned scale or shift)
# Source : https://gluon.mxnet.io/chapter04_convolutional-neural-networks/cnn-batch-norm-scratch.html
def batch_norm(X, eps=0.001):
    # Normalize each channel over the batch and spatial dimensions.
    N, C, H, W = X.shape
    mean = X.mean(dim=(0, 2, 3)).view(1, C, 1, 1)
    variance = ((X - mean) ** 2).mean(dim=(0, 2, 3)).view(1, C, 1, 1)
    return (X - mean) / torch.sqrt(variance + eps)
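# Sanity check (illustrative): each channel ends up with roughly zero mean and
# unit variance over the batch and spatial dimensions.
# >>> X = 3 * torch.randn(4, 2, 8, 8) + 5
# >>> Xn = batch_norm(X)
# >>> Xn.mean(dim=(0, 2, 3)).abs().max() < 1e-5
# tensor(True)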
# MantraNet (equivalent to the TensorFlow implementation at https://github.com/ISICV/ManTraNet)
class MantraNet(nn.Module):
def __init__(self, in_channel=3, eps=10 ** (-6), device=device):
super(MantraNet, self).__init__()
self.eps = eps
self.relu = nn.ReLU()
self.device = device
# ********** IMAGE MANIPULATION TRACE FEATURE EXTRACTOR *********
## Initialisation
self.init_conv = nn.Conv2d(in_channel, 4, 5, 1, padding=0, bias=False)
self.BayarConv2D = nn.Conv2d(in_channel, 3, 5, 1, padding=0, bias=False)
        # Bayar constraint masks: the centre tap is fixed to -1 and the
        # remaining taps are renormalized to sum to 1 at every forward pass.
        # Plain float32 torch tensors are used so the in-place updates of the
        # float32 conv weights below do not hit a dtype mismatch.
        self.bayar_mask = torch.ones((5, 5), device=self.device)
        self.bayar_mask[2, 2] = 0
        self.bayar_final = torch.zeros((5, 5), device=self.device)
        self.bayar_final[2, 2] = -1
self.SRMConv2D = nn.Conv2d(in_channel, 9, 5, 1, padding=0, bias=False)
        self.SRMConv2D.weight.data = torch.load(
            "MantraNet/MantraNetv4.pt", map_location="cpu"
        )["SRMConv2D.weight"]
##SRM filters (fixed)
for param in self.SRMConv2D.parameters():
param.requires_grad = False
self.middle_and_last_block = nn.ModuleList(
[
nn.Conv2d(16, 32, 3, 1, padding=0),
nn.ReLU(),
nn.Conv2d(32, 64, 3, 1, padding=0),
nn.ReLU(),
nn.Conv2d(64, 64, 3, 1, padding=0),
nn.ReLU(),
nn.Conv2d(64, 128, 3, 1, padding=0),
nn.ReLU(),
nn.Conv2d(128, 128, 3, 1, padding=0),
nn.ReLU(),
nn.Conv2d(128, 128, 3, 1, padding=0),
nn.ReLU(),
nn.Conv2d(128, 256, 3, 1, padding=0),
nn.ReLU(),
nn.Conv2d(256, 256, 3, 1, padding=0),
nn.ReLU(),
nn.Conv2d(256, 256, 3, 1, padding=0),
nn.ReLU(),
nn.Conv2d(256, 256, 3, 1, padding=0),
nn.ReLU(),
nn.Conv2d(256, 256, 3, 1, padding=0),
nn.ReLU(),
nn.Conv2d(256, 256, 3, 1, padding=0),
]
)
# ********** LOCAL ANOMALY DETECTOR *********
self.adaptation = nn.Conv2d(256, 64, 1, 1, padding=0, bias=False)
self.sigma_F = nn.Parameter(torch.zeros((1, 64, 1, 1)), requires_grad=True)
self.pool31 = nn.AvgPool2d(31, stride=1, padding=15, count_include_pad=False)
self.pool15 = nn.AvgPool2d(15, stride=1, padding=7, count_include_pad=False)
self.pool7 = nn.AvgPool2d(7, stride=1, padding=3, count_include_pad=False)
self.convlstm = ConvLSTM(
input_dim=64,
hidden_dim=8,
kernel_size=(7, 7),
num_layers=1,
batch_first=False,
bias=True,
return_all_layers=False,
)
self.end = nn.Sequential(nn.Conv2d(8, 1, 7, 1, padding=3), nn.Sigmoid())
def forward(self, x):
B, nb_channel, H, W = x.shape
        # The global average pool depends on the input size, so it is
        # (re)built lazily at forward time.
        if not self.training or not hasattr(self, "GlobalPool"):
            self.GlobalPool = nn.AvgPool2d((H, W), stride=1)
# Normalization
x = x / 255.0 * 2 - 1
## Image Manipulation Trace Feature Extractor
## **Bayar constraints**
self.BayarConv2D.weight.data *= self.bayar_mask
self.BayarConv2D.weight.data *= torch.pow(
self.BayarConv2D.weight.data.sum(axis=(2, 3)).view(3, 3, 1, 1), -1
)
self.BayarConv2D.weight.data += self.bayar_final
# Symmetric padding
x = symm_pad(x, (2, 2, 2, 2))
conv_init = self.init_conv(x)
conv_bayar = self.BayarConv2D(x)
conv_srm = self.SRMConv2D(x)
first_block = torch.cat([conv_init, conv_srm, conv_bayar], axis=1)
first_block = self.relu(first_block)
last_block = first_block
for layer in self.middle_and_last_block:
if isinstance(layer, nn.Conv2d):
last_block = symm_pad(last_block, (1, 1, 1, 1))
last_block = layer(last_block)
# L2 normalization
last_block = F.normalize(last_block, dim=1, p=2)
## Local Anomaly Feature Extraction
X_adapt = self.adaptation(last_block)
X_adapt = batch_norm(X_adapt)
# Z-pool concatenation
mu_T = self.GlobalPool(X_adapt)
sigma_T = torch.sqrt(self.GlobalPool(torch.square(X_adapt - mu_T)))
sigma_T = torch.max(sigma_T, self.sigma_F + self.eps)
inv_sigma_T = torch.pow(sigma_T, -1)
zpoolglobal = torch.abs((mu_T - X_adapt) * inv_sigma_T)
mu_31 = self.pool31(X_adapt)
zpool31 = torch.abs((mu_31 - X_adapt) * inv_sigma_T)
mu_15 = self.pool15(X_adapt)
zpool15 = torch.abs((mu_15 - X_adapt) * inv_sigma_T)
mu_7 = self.pool7(X_adapt)
zpool7 = torch.abs((mu_7 - X_adapt) * inv_sigma_T)
input_lstm = torch.cat(
[
zpool7.unsqueeze(0),
zpool15.unsqueeze(0),
zpool31.unsqueeze(0),
zpoolglobal.unsqueeze(0),
],
axis=0,
)
        # ConvLSTM
_, output_lstm = self.convlstm(input_lstm)
output_lstm = output_lstm[0][0]
final_output = self.end(output_lstm)
return final_output
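# Usage sketch (illustrative; requires MantraNet/MantraNetv4.pt for the fixed
# SRM weights). The network maps a [0, 255] RGB batch to a per-pixel forgery
# probability map with the same spatial size:
# >>> net = MantraNet(device="cpu").eval()
# >>> with torch.no_grad():
# ...     mask = net(torch.randint(0, 256, (1, 3, 128, 128)).float())
# >>> mask.shape
# torch.Size([1, 1, 128, 128])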
# Slight modification of the original MantraNet using a GRU instead of an LSTM
class MantraNet_GRU(nn.Module):
def __init__(self, device, in_channel=3, eps=10 ** (-4)):
super(MantraNet_GRU, self).__init__()
self.eps = eps
self.relu = nn.ReLU()
self.device = device
# ********** IMAGE MANIPULATION TRACE FEATURE EXTRACTOR *********
## Initialisation
self.init_conv = nn.Conv2d(in_channel, 4, 5, 1, padding=0, bias=False)
self.BayarConv2D = nn.Conv2d(in_channel, 3, 5, 1, padding=0, bias=False)
self.SRMConv2D = nn.Conv2d(in_channel, 9, 5, 1, padding=0, bias=False)
        self.SRMConv2D.weight.data = torch.load(
            "MantraNet/MantraNetv4.pt", map_location="cpu"
        )["SRMConv2D.weight"]
##SRM filters (fixed)
for param in self.SRMConv2D.parameters():
param.requires_grad = False
self.middle_and_last_block = nn.ModuleList(
[
nn.Conv2d(16, 32, 3, 1, padding=0),
nn.ReLU(),
nn.Conv2d(32, 64, 3, 1, padding=0),
nn.ReLU(),
nn.Conv2d(64, 64, 3, 1, padding=0),
nn.ReLU(),
nn.Conv2d(64, 128, 3, 1, padding=0),
nn.ReLU(),
nn.Conv2d(128, 128, 3, 1, padding=0),
nn.ReLU(),
nn.Conv2d(128, 128, 3, 1, padding=0),
nn.ReLU(),
nn.Conv2d(128, 256, 3, 1, padding=0),
nn.ReLU(),
nn.Conv2d(256, 256, 3, 1, padding=0),
nn.ReLU(),
nn.Conv2d(256, 256, 3, 1, padding=0),
nn.ReLU(),
nn.Conv2d(256, 256, 3, 1, padding=0),
nn.ReLU(),
nn.Conv2d(256, 256, 3, 1, padding=0),
nn.ReLU(),
nn.Conv2d(256, 256, 3, 1, padding=0),
]
)
# ********** LOCAL ANOMALY DETECTOR *********
self.adaptation = nn.Conv2d(256, 64, 1, 1, padding=0, bias=False)
self.sigma_F = nn.Parameter(torch.zeros((1, 64, 1, 1)), requires_grad=True)
self.pool31 = nn.AvgPool2d(31, stride=1, padding=15, count_include_pad=False)
self.pool15 = nn.AvgPool2d(15, stride=1, padding=7, count_include_pad=False)
self.pool7 = nn.AvgPool2d(7, stride=1, padding=3, count_include_pad=False)
self.convgru = ConvGRU(
input_dim=64,
hidden_dim=8,
kernel_size=(7, 7),
num_layers=1,
batch_first=False,
bias=True,
return_all_layers=False,
)
self.end = nn.Sequential(nn.Conv2d(8, 1, 7, 1, padding=3), nn.Sigmoid())
self.bayar_mask = torch.ones((5, 5), device=self.device)
self.bayar_final = torch.zeros((5, 5), device=self.device)
def forward(self, x):
B, nb_channel, H, W = x.shape
        # The global average pool depends on the input size, so it is
        # (re)built lazily at forward time.
        if not self.training or not hasattr(self, "GlobalPool"):
            self.GlobalPool = nn.AvgPool2d((H, W), stride=1)
# Normalization
x = x / 255.0 * 2 - 1
## Image Manipulation Trace Feature Extractor
## **Bayar constraints**
self.bayar_mask[2, 2] = 0
self.bayar_final[2, 2] = -1
self.BayarConv2D.weight.data *= self.bayar_mask
self.BayarConv2D.weight.data *= torch.pow(
self.BayarConv2D.weight.data.sum(axis=(2, 3)).view(3, 3, 1, 1), -1
)
self.BayarConv2D.weight.data += self.bayar_final
# Symmetric padding
X = symm_pad(x, (2, 2, 2, 2))
conv_init = self.init_conv(X)
conv_bayar = self.BayarConv2D(X)
conv_srm = self.SRMConv2D(X)
first_block = torch.cat([conv_init, conv_srm, conv_bayar], axis=1)
first_block = self.relu(first_block)
last_block = first_block
for layer in self.middle_and_last_block:
if isinstance(layer, nn.Conv2d):
last_block = symm_pad(last_block, (1, 1, 1, 1))
last_block = layer(last_block)
# L2 normalization
last_block = F.normalize(last_block, dim=1, p=2)
## Local Anomaly Feature Extraction
X_adapt = self.adaptation(last_block)
X_adapt = batch_norm(X_adapt)
# Z-pool concatenation
mu_T = self.GlobalPool(X_adapt)
sigma_T = torch.sqrt(self.GlobalPool(torch.square(X_adapt - mu_T)))
sigma_T = torch.max(sigma_T, self.sigma_F + self.eps)
inv_sigma_T = torch.pow(sigma_T, -1)
zpoolglobal = torch.abs((mu_T - X_adapt) * inv_sigma_T)
mu_31 = self.pool31(X_adapt)
zpool31 = torch.abs((mu_31 - X_adapt) * inv_sigma_T)
mu_15 = self.pool15(X_adapt)
zpool15 = torch.abs((mu_15 - X_adapt) * inv_sigma_T)
mu_7 = self.pool7(X_adapt)
zpool7 = torch.abs((mu_7 - X_adapt) * inv_sigma_T)
input_gru = torch.cat(
[
zpool7.unsqueeze(0),
zpool15.unsqueeze(0),
zpool31.unsqueeze(0),
zpoolglobal.unsqueeze(0),
],
axis=0,
)
        # ConvGRU
_, output_gru = self.convgru(input_gru)
output_gru = output_gru[0]
final_output = self.end(output_gru)
return final_output
## Load pre-trained weights:
def pre_trained_model(weight_path="MantraNet/MantraNetv4.pt", device=device):
    model = MantraNet(device=device)
    model.load_state_dict(torch.load(weight_path, map_location=device))
    return model
# Predict a forgery mask for an image
def check_forgery(model, img_path="./example.jpg", device=device):
    model.to(device)
    model.eval()
    im = Image.open(img_path).convert("RGB")  # drop any alpha channel
    im = np.array(im)
    original_image = im.copy()
    # (H, W, C) uint8 image -> (1, C, H, W) float batch
    im = torch.tensor(im, dtype=torch.float32).permute(2, 0, 1).unsqueeze(0)
    im = im.to(device)
with torch.no_grad():
final_output = model(im)
fig = plt.figure(figsize=(20, 20))
plt.subplot(1, 3, 1)
plt.imshow(original_image)
plt.title("Original image")
plt.subplot(1, 3, 2)
plt.imshow((final_output[0][0]).cpu().detach(), cmap="gray")
plt.title("Predicted forgery mask")
plt.subplot(1, 3, 3)
plt.imshow(
(final_output[0][0].cpu().detach().unsqueeze(2) > 0.2)
* torch.tensor(original_image)
)
plt.title("Suspicious regions detected")
return fig
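# Example (illustrative; assumes the pre-trained weights and an image exist):
# >>> model = pre_trained_model()
# >>> fig = check_forgery(model, img_path="./example.jpg")
# >>> fig.savefig("forgery_check.png")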
class ForgeryDetector(pl.LightningModule):
    # Model Initialization/Creation
    def __init__(self, train_loader, detector=None, lr=0.001):
        super(ForgeryDetector, self).__init__()
        # Build the default MantraNet lazily rather than as a default
        # argument, which would be constructed once at import time.
        self.detector = detector if detector is not None else MantraNet()
        self.train_loader = train_loader
        self.cpt = -1
        self.lr = lr
# Forward Pass of Model
def forward(self, x):
return self.detector(x)
# Loss Function
def loss(self, y_hat, y):
return nn.BCELoss()(y_hat, y)
# Optimizers
def configure_optimizers(self):
optimizer = torch.optim.AdamW(self.detector.parameters(), lr=self.lr)
# scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1)
        # Return the list of optimizers; the second, empty list is for schedulers (if any)
return [optimizer], []
    # Called after prepare_data to provide the training DataLoader
def train_dataloader(self):
return self.train_loader
# Training Loop
def training_step(self, batch, batch_idx):
# batch returns x and y tensors
real_images, mask = batch
B, _, _, _ = real_images.size()
self.cpt += 1
predicted = self.detector(real_images).view(B, -1)
mask = mask.view(B, -1)
loss = self.loss(predicted, mask)
self.log("BCELoss", loss, on_step=True, on_epoch=True, prog_bar=True)
output = OrderedDict(
{
"loss": loss,
}
)
return output
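# Training sketch (illustrative; `train_loader` is an assumed DataLoader that
# yields (image, mask) batches with images in [0, 255] and masks in [0, 1]):
# >>> detector = ForgeryDetector(train_loader)
# >>> trainer = pl.Trainer(max_epochs=10)
# >>> trainer.fit(detector)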