|
""" |
|
This file defines all BusterNet related custom layers |
|
""" |
|
from __future__ import print_function |
|
from tensorflow.keras.layers import Conv2D, MaxPooling2D |
|
from tensorflow.keras.layers import Layer, Input, Lambda |
|
from tensorflow.keras.layers import BatchNormalization, Activation, Concatenate |
|
from tensorflow.keras.models import Model |
|
from tensorflow.keras.applications.vgg16 import preprocess_input |
|
from tensorflow.keras import backend as K |
|
import tensorflow as tf |
|
|
|
|
|
def std_norm_along_chs(x):
    """Normalize a 4D tensor along its channel (last) axis.

    Each per-pixel feature vector is shifted to zero mean and divided by
    its standard deviation; the std is clamped from below at 1e-4 to
    avoid division by (near) zero.

    Input:
        x = tensor4d, (n_samples, n_rows, n_cols, n_feats)
    Output:
        xn = tensor4d, same shape as x, normalized version of x
    """
    mu = K.mean(x, axis=-1, keepdims=True)
    sigma = K.std(x, axis=-1, keepdims=True)
    return (x - mu) / K.maximum(1e-4, sigma)
|
|
|
|
|
def BnInception(x, nb_inc=16, inc_filt_list=((1, 1), (3, 3), (5, 5)), name="uinc"):
    """Basic Google inception module with batch normalization.

    Applies one Conv2D per kernel size in ``inc_filt_list`` (all with
    ``nb_inc`` filters, linear activation, 'same' padding), concatenates
    the results along the channel axis, then batch-normalizes and applies
    a ReLU.

    Input:
        x = tensor4d, (n_samples, n_rows, n_cols, n_feats)
        nb_inc = int, number of filters in individual Conv2D
        inc_filt_list = sequence of kernel sizes, individual Conv2D kernel size
            (a tuple default is used instead of a list to avoid the
            mutable-default-argument pitfall; lists still work)
        name = str, name of module
    Output:
        xn = tensor4d, (n_samples, n_rows, n_cols, n_new_feats)
    """
    uc_list = []
    for idx, ftuple in enumerate(inc_filt_list):
        uc = Conv2D(
            nb_inc,
            ftuple,
            activation="linear",
            padding="same",
            name=name + "_c%d" % idx,
        )(x)
        uc_list.append(uc)
    # Concatenate only when there is more than one branch; Keras'
    # Concatenate layer rejects a single-tensor input list.
    if len(uc_list) > 1:
        uc_merge = Concatenate(axis=-1, name=name + "_merge")(uc_list)
    else:
        uc_merge = uc_list[0]
    uc_norm = BatchNormalization(name=name + "_bn")(uc_merge)
    xn = Activation("relu", name=name + "_re")(uc_norm)
    return xn
|
|
|
|
|
class SelfCorrelationPercPooling(Layer):
    """Custom Self-Correlation Percentile Pooling Layer.

    Computes the full self-correlation map of the input feature tensor
    (each spatial location against every other) and keeps, per location,
    a fixed number of sorted percentile values instead of the full
    correlation vector.

    Argument:
        nb_pools = int, number of percentile poolings
            (if None, all ranks except the trivial self-match are kept)
    Input:
        x = tensor4d, (n_samples, n_rows, n_cols, n_feats)
            n_rows and n_cols must be statically known
    Output:
        x_pool = tensor4d, (n_samples, n_rows, n_cols, nb_pools)
    """

    def __init__(self, nb_pools=256, **kwargs):
        self.nb_pools = nb_pools
        super(SelfCorrelationPercPooling, self).__init__(**kwargs)

    def build(self, input_shape):
        # no trainable weights
        self.built = True

    def call(self, x, mask=None):
        # static shape: spatial dims must be known at graph time
        bsize, nb_rows, nb_cols, nb_feats = K.int_shape(x)
        nb_maps = nb_rows * nb_cols
        # flatten spatial dims and compute the feature-scaled Gram matrix
        x_3d = K.reshape(x, tf.stack([-1, nb_maps, nb_feats]))
        x_corr_3d = (
            tf.matmul(x_3d, x_3d, transpose_a=False, transpose_b=True) / nb_feats
        )
        x_corr = K.reshape(x_corr_3d, tf.stack([-1, nb_rows, nb_cols, nb_maps]))
        # ranks to keep; rank 0 (a location's correlation with itself) is skipped
        if self.nb_pools is not None:
            ranks = K.cast(
                K.round(tf.linspace(1.0, nb_maps - 1, self.nb_pools)), "int32"
            )
        else:
            ranks = tf.range(1, nb_maps, dtype="int32")
        # sort correlations in descending order along the last axis
        x_sort, _ = tf.nn.top_k(x_corr, k=nb_maps, sorted=True)
        # tf.gather indexes axis 0, so rotate the rank axis to the front,
        # gather the chosen ranks, then rotate back
        x_f1st_sort = K.permute_dimensions(x_sort, (3, 0, 1, 2))
        x_f1st_pool = tf.gather(x_f1st_sort, ranks)
        x_pool = K.permute_dimensions(x_f1st_pool, (1, 2, 3, 0))
        return x_pool

    def compute_output_shape(self, input_shape):
        bsize, nb_rows, nb_cols, nb_feats = input_shape
        nb_pools = (
            self.nb_pools if (self.nb_pools is not None) else (nb_rows * nb_cols - 1)
        )
        return tuple([bsize, nb_rows, nb_cols, nb_pools])

    def get_config(self):
        # expose nb_pools so models using this layer can be saved/reloaded
        config = super(SelfCorrelationPercPooling, self).get_config()
        config["nb_pools"] = self.nb_pools
        return config
|
|
|
|
|
class BilinearUpSampling2D(Layer):
    """Custom 2x bilinear upsampling layer.

    Input:
        x = tensor4d, (n_samples, n_rows, n_cols, n_feats)
    Output:
        x2 = tensor4d, (n_samples, 2*n_rows, 2*n_cols, n_feats)
    """

    def call(self, x, mask=None):
        # Use the dynamic tensor shape instead of K.int_shape so the layer
        # also works when the spatial dims are unknown (None) at graph
        # construction time; the result is identical for static shapes.
        in_shape = tf.shape(x)
        new_size = tf.stack([in_shape[1] * 2, in_shape[2] * 2])
        return tf.image.resize(x, new_size)

    def compute_output_shape(self, input_shape):
        bsize, nb_rows, nb_cols, nb_filts = input_shape
        # propagate None for dynamic spatial dims
        new_rows = nb_rows * 2 if nb_rows is not None else None
        new_cols = nb_cols * 2 if nb_cols is not None else None
        return tuple([bsize, new_rows, new_cols, nb_filts])
|
|
|
|
|
class ResizeBack(Layer):
    """Custom bilinear resize layer.

    Resizes the first input's spatial dimensions to match those of the
    second (reference) input.

    Input:
        x = tensor4d, (n_samples, n_rowsX, n_colsX, n_featsX)
        r = tensor4d, (n_samples, n_rowsR, n_colsR, n_featsR)
    Output:
        xn = tensor4d, (n_samples, n_rowsR, n_colsR, n_featsX)
    """

    def call(self, x):
        src, ref = x
        # take the reference's dynamic spatial extent as the target size
        ref_shape = tf.shape(ref)
        return tf.image.resize(src, [ref_shape[1], ref_shape[2]])

    def compute_output_shape(self, input_shapes):
        src_shape, ref_shape = input_shapes
        # batch and channels from the source, rows/cols from the reference
        return (src_shape[0], ref_shape[1], ref_shape[2], src_shape[-1])
|
|
|
|
|
class Preprocess(Layer):
    """Basic preprocess layer for BusterNet

    More precisely, it does the following two things
    1) normalize input image size to (256,256) to speed up processing
    2) substract channel-wise means if necessary
    """

    def call(self, x, mask=None):
        # static input shape; spatial dims may be None for dynamic inputs
        bsize, nb_rows, nb_cols, nb_colors = K.int_shape(x)
        if (nb_rows != 256) or (nb_cols != 256):
            # bilinear resize to the fixed working resolution
            x256 = tf.image.resize(x, [256, 256], name="resize")
        else:
            x256 = x

        if K.dtype(x) == "float32":
            # float32 input is treated as already preprocessed
            # NOTE(review): assumes float32 callers pass VGG-ready values
            # (mean-subtracted); confirm against the feeding pipeline.
            xout = x256
        else:
            # non-float input (presumably uint8 RGB) gets the standard
            # VGG16 preprocessing (channel-mean subtraction, BGR order)
            xout = preprocess_input(x256)
        return xout

    def compute_output_shape(self, input_shape):
        # spatial size is always normalized to 256x256, 3 channels
        return (input_shape[0], 256, 256, 3)
|
|
|
|
|
def create_cmfd_similarity_branch(
    img_shape=(256, 256, 3), nb_pools=100, name="simiDet"
):
    """Create the similarity branch for copy-move forgery detection.

    Pipeline: VGG16-style feature extractor -> channel-wise feature
    standardization -> self-correlation percentile pooling -> multi-scale
    inception decoder -> 1-channel sigmoid similarity mask.

    NOTE(review): nb_pools is currently unused; SelfCorrelationPercPooling
    keeps its own default (256). Wiring it through would change tensor
    shapes and break pretrained weights — confirm before changing.
    """
    img_input = Input(shape=img_shape, name=name + "_in")

    # --- VGG16-style feature extraction (blocks 1-4) ---
    bname = name + "_cnn"
    x = img_input
    for blk_idx, (nb_convs, nb_filts) in enumerate(
        [(2, 64), (2, 128), (3, 256), (3, 512)], start=1
    ):
        for conv_idx in range(1, nb_convs + 1):
            x = Conv2D(
                nb_filts,
                (3, 3),
                activation="relu",
                padding="same",
                name="%s_b%dc%d" % (bname, blk_idx, conv_idx),
            )(x)
        x = MaxPooling2D(
            (2, 2), strides=(2, 2), name="%s_b%dp" % (bname, blk_idx)
        )(x)

    # standardize deep features along the channel axis
    xx = Activation(std_norm_along_chs, name=bname + "_sn")(x)

    # --- self-correlation with percentile pooling ---
    bname = name + "_corr"
    xcorr = SelfCorrelationPercPooling(name=bname + "_corr")(xx)
    xn = BatchNormalization(name=bname + "_bn")(xcorr)

    # --- multi-scale inception decoder ---
    patch_list = [(1, 1), (3, 3), (5, 5)]
    bname = name + "_dconv"
    f16 = BnInception(xn, 8, patch_list, name=bname + "_mpf")

    f32 = BilinearUpSampling2D(name=bname + "_bx2")(f16)
    dx32 = BnInception(f32, 6, patch_list, name=bname + "_dx2")

    # at each scale, merge a skip upsample (a) with the decoded path (b)
    f64a = BilinearUpSampling2D(name=bname + "_bx4a")(f32)
    f64b = BilinearUpSampling2D(name=bname + "_bx4b")(dx32)
    f64 = Concatenate(axis=-1, name=name + "_dx4_m")([f64a, f64b])
    dx64 = BnInception(f64, 4, patch_list, name=bname + "_dx4")

    f128a = BilinearUpSampling2D(name=bname + "_bx8a")(f64a)
    f128b = BilinearUpSampling2D(name=bname + "_bx8b")(dx64)
    f128 = Concatenate(axis=-1, name=name + "_dx8_m")([f128a, f128b])
    dx128 = BnInception(f128, 2, patch_list, name=bname + "_dx8")

    f256a = BilinearUpSampling2D(name=bname + "_bx16a")(f128a)
    f256b = BilinearUpSampling2D(name=bname + "_bx16b")(dx128)
    f256 = Concatenate(axis=-1, name=name + "_dx16_m")([f256a, f256b])
    dx256 = BnInception(f256, 2, patch_list, name=bname + "_dx16")

    # fuse the final skip path with the decoded features, then a last
    # inception block with larger receptive fields
    fm256 = Concatenate(axis=-1, name=name + "_mfeat")([f256a, dx256])
    masks = BnInception(fm256, 2, [(5, 5), (7, 7), (11, 11)], name=bname + "_dxF")

    # 1-channel sigmoid similarity mask
    pred_mask = Conv2D(
        1, (3, 3), activation="sigmoid", name=name + "_pred_mask", padding="same"
    )(masks)

    return Model(inputs=img_input, outputs=pred_mask, name=name)
|
|
|
|
|
def create_cmfd_manipulation_branch(img_shape=(256, 256, 3), name="maniDet"):
    """Create the manipulation branch for copy-move forgery detection.

    Pipeline: VGG16-style feature extractor -> multi-scale inception
    decoder -> 1-channel sigmoid manipulation mask.
    """
    img_input = Input(shape=img_shape, name=name + "_in")

    # --- VGG16-style feature extraction (blocks 1-4) ---
    bname = name + "_cnn"
    x = img_input
    for blk_idx, (nb_convs, nb_filts) in enumerate(
        [(2, 64), (2, 128), (3, 256), (3, 512)], start=1
    ):
        for conv_idx in range(1, nb_convs + 1):
            x = Conv2D(
                nb_filts,
                (3, 3),
                activation="relu",
                padding="same",
                name="%s_b%dc%d" % (bname, blk_idx, conv_idx),
            )(x)
        x = MaxPooling2D(
            (2, 2), strides=(2, 2), name="%s_b%dp" % (bname, blk_idx)
        )(x)

    # --- multi-scale inception decoder (no skip connections here) ---
    patch_list = [(1, 1), (3, 3), (5, 5)]
    bname = name + "_dconv"

    f16 = BnInception(x, 8, patch_list, name=bname + "_mpf")

    f32 = BilinearUpSampling2D(name=bname + "_bx2")(f16)
    dx32 = BnInception(f32, 6, patch_list, name=bname + "_dx2")

    f64 = BilinearUpSampling2D(name=bname + "_bx4")(dx32)
    dx64 = BnInception(f64, 4, patch_list, name=bname + "_dx4")

    f128 = BilinearUpSampling2D(name=bname + "_bx8")(dx64)
    dx128 = BnInception(f128, 2, patch_list, name=bname + "_dx8")

    f256 = BilinearUpSampling2D(name=bname + "_bx16")(dx128)
    dx256 = BnInception(f256, 2, [(5, 5), (7, 7), (11, 11)], name=bname + "_dx16")

    # 1-channel sigmoid manipulation mask
    pred_mask = Conv2D(
        1, (3, 3), activation="sigmoid", name=bname + "_pred_mask", padding="same"
    )(dx256)

    # NOTE(review): the model (and its head) are named after the decoder
    # scope (bname), unlike the similarity branch which uses `name`;
    # kept as-is since layer names must match pretrained weight files.
    return Model(inputs=img_input, outputs=pred_mask, name=bname)
|
|
|
|
|
def create_BusterNet_testing_model(weight_file=None):
    """create a busterNet testing model with pretrained weights

    Builds both detection branches, strips their 1-channel prediction
    heads, fuses the remaining feature maps, and predicts a 3-class
    softmax mask that is resized back to the input resolution.
    Optionally loads pretrained weights from ``weight_file`` (failures
    are reported, not raised).
    """
    simi_branch = create_cmfd_similarity_branch()
    mani_branch = create_cmfd_manipulation_branch()

    # drop each branch's final prediction layer; keep the feature maps
    simi_featex = Model(
        inputs=simi_branch.inputs,
        outputs=simi_branch.layers[-2].output,
        name="simiFeatex",
    )
    mani_featex = Model(
        inputs=mani_branch.inputs,
        outputs=mani_branch.layers[-2].output,
        name="maniFeatex",
    )

    # variable-size input, normalized to 256x256 by the preprocess layer
    img_raw = Input(shape=(None, None, 3), name="image_in")
    img_in = Preprocess(name="preprocess")(img_raw)

    # fuse the two feature streams and predict a 3-class mask
    fused = Concatenate(axis=-1, name="merge")(
        [simi_featex(img_in), mani_featex(img_in)]
    )
    fused = BnInception(fused, 3, name="fusion")
    mask_out = Conv2D(
        3, (3, 3), padding="same", activation="softmax", name="pred_mask"
    )(fused)

    # restore the mask to the raw input's spatial resolution
    mask_out = ResizeBack(name="restore")([mask_out, img_raw])

    model = Model(inputs=img_raw, outputs=mask_out, name="busterNet")
    if weight_file is not None:
        try:
            model.load_weights(weight_file)
            print(
                "INFO: successfully load pretrained weights from {}".format(weight_file)
            )
        except Exception as e:
            print(
                "INFO: fail to load pretrained weights from {} for reason: {}".format(
                    weight_file, e
                )
            )
    return model
|
|