Add files
- .gitattributes +1 -0
- BusterNet/BusterNetCore.py +396 -0
- BusterNet/BusterNetUtils.py +27 -0
- BusterNet/README.md +74 -0
- BusterNet/__pycache__/BusterNetCore.cpython-37.pyc +0 -0
- BusterNet/__pycache__/BusterNetUtils.cpython-37.pyc +0 -0
- BusterNet/pretrained_busterNet.hd5 +3 -0
- MantraNet/MantraNetv4.pt +3 -0
- MantraNet/__init__.py +1 -0
- MantraNet/__pycache__/__init__.cpython-37.pyc +0 -0
- MantraNet/__pycache__/mantranet.cpython-37.pyc +0 -0
- MantraNet/mantranet.py +946 -0
- __init__.py +2 -0
- app.py +37 -0
.gitattributes
CHANGED
@@ -25,3 +25,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zstandard filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+*.hd5 filter=lfs diff=lfs merge=lfs -text
BusterNet/BusterNetCore.py
ADDED
@@ -0,0 +1,396 @@
"""
This file defines all BusterNet related custom layers
"""
from __future__ import print_function
from tensorflow.keras.layers import Conv2D, MaxPooling2D
from tensorflow.keras.layers import Layer, Input, Lambda
from tensorflow.keras.layers import BatchNormalization, Activation, Concatenate
from tensorflow.keras.models import Model
from tensorflow.keras.applications.vgg16 import preprocess_input
from tensorflow.keras import backend as K
import tensorflow as tf


def std_norm_along_chs(x):
    """Data normalization along the channel axis
    Input:
        x = tensor4d, (n_samples, n_rows, n_cols, n_feats)
    Output:
        xn = tensor4d, same shape as x, normalized version of x
    """
    avg = K.mean(x, axis=-1, keepdims=True)
    std = K.maximum(1e-4, K.std(x, axis=-1, keepdims=True))
    return (x - avg) / std


def BnInception(x, nb_inc=16, inc_filt_list=[(1, 1), (3, 3), (5, 5)], name="uinc"):
    """Basic Google Inception module with batch normalization
    Input:
        x = tensor4d, (n_samples, n_rows, n_cols, n_feats)
        nb_inc = int, number of filters in each individual Conv2D
        inc_filt_list = list of kernel sizes, one per individual Conv2D
        name = str, name of module
    Output:
        xn = tensor4d, (n_samples, n_rows, n_cols, n_new_feats)
    """
    uc_list = []
    for idx, ftuple in enumerate(inc_filt_list):
        uc = Conv2D(
            nb_inc,
            ftuple,
            activation="linear",
            padding="same",
            name=name + "_c%d" % idx,
        )(x)
        uc_list.append(uc)
    if len(uc_list) > 1:
        uc_merge = Concatenate(axis=-1, name=name + "_merge")(uc_list)
    else:
        uc_merge = uc_list[0]
    uc_norm = BatchNormalization(name=name + "_bn")(uc_merge)
    xn = Activation("relu", name=name + "_re")(uc_norm)
    return xn


class SelfCorrelationPercPooling(Layer):
    """Custom Self-Correlation Percentile Pooling Layer
    Argument:
        nb_pools = int, number of percentile poolings
    Input:
        x = tensor4d, (n_samples, n_rows, n_cols, n_feats)
    Output:
        x_pool = tensor4d, (n_samples, n_rows, n_cols, nb_pools)
    """

    def __init__(self, nb_pools=256, **kwargs):
        self.nb_pools = nb_pools
        super(SelfCorrelationPercPooling, self).__init__(**kwargs)

    def build(self, input_shape):
        self.built = True

    def call(self, x, mask=None):
        # parse input feature shape
        bsize, nb_rows, nb_cols, nb_feats = K.int_shape(x)
        nb_maps = nb_rows * nb_cols
        # self correlation
        x_3d = K.reshape(x, tf.stack([-1, nb_maps, nb_feats]))
        x_corr_3d = (
            tf.matmul(x_3d, x_3d, transpose_a=False, transpose_b=True) / nb_feats
        )
        x_corr = K.reshape(x_corr_3d, tf.stack([-1, nb_rows, nb_cols, nb_maps]))
        # argsort response maps along the translation dimension
        if self.nb_pools is not None:
            ranks = K.cast(
                K.round(tf.linspace(1.0, nb_maps - 1, self.nb_pools)), "int32"
            )
        else:
            ranks = tf.range(1, nb_maps, dtype="int32")
        x_sort, _ = tf.nn.top_k(x_corr, k=nb_maps, sorted=True)
        # pool out x features at the ranks of interest
        # NOTE: tf v1.1 only supports indexing along the 1st dimension
        x_f1st_sort = K.permute_dimensions(x_sort, (3, 0, 1, 2))
        x_f1st_pool = tf.gather(x_f1st_sort, ranks)
        x_pool = K.permute_dimensions(x_f1st_pool, (1, 2, 3, 0))
        return x_pool

    def compute_output_shape(self, input_shape):
        bsize, nb_rows, nb_cols, nb_feats = input_shape
        nb_pools = (
            self.nb_pools if (self.nb_pools is not None) else (nb_rows * nb_cols - 1)
        )
        return tuple([bsize, nb_rows, nb_cols, nb_pools])


class BilinearUpSampling2D(Layer):
    """Custom 2x bilinear upsampling layer
    Input:
        x = tensor4d, (n_samples, n_rows, n_cols, n_feats)
    Output:
        x2 = tensor4d, (n_samples, 2*n_rows, 2*n_cols, n_feats)
    """

    def call(self, x, mask=None):
        bsize, nb_rows, nb_cols, nb_filts = K.int_shape(x)
        new_size = tf.constant([nb_rows * 2, nb_cols * 2], dtype=tf.int32)
        return tf.image.resize(x, new_size)

    def compute_output_shape(self, input_shape):
        bsize, nb_rows, nb_cols, nb_filts = input_shape
        return tuple([bsize, nb_rows * 2, nb_cols * 2, nb_filts])


class ResizeBack(Layer):
    """Custom bilinear resize layer
    Resize x's spatial dimensions to those of r

    Input:
        x = tensor4d, (n_samples, n_rowsX, n_colsX, n_featsX)
        r = tensor4d, (n_samples, n_rowsR, n_colsR, n_featsR)
    Output:
        xn = tensor4d, (n_samples, n_rowsR, n_colsR, n_featsX)
    """

    def call(self, x):
        t, r = x
        new_size = [tf.shape(r)[1], tf.shape(r)[2]]
        return tf.image.resize(t, new_size)

    def compute_output_shape(self, input_shapes):
        tshape, rshape = input_shapes
        return (tshape[0],) + rshape[1:3] + (tshape[-1],)


class Preprocess(Layer):
    """Basic preprocessing layer for BusterNet

    More precisely, it does the following two things:
    1) normalize the input image size to (256, 256) to speed up processing
    2) subtract channel-wise means if necessary
    """

    def call(self, x, mask=None):
        # parse input image shape
        bsize, nb_rows, nb_cols, nb_colors = K.int_shape(x)
        if (nb_rows != 256) or (nb_cols != 256):
            # resize image if different from (256, 256)
            x256 = tf.image.resize(x, [256, 256], name="resize")
        else:
            x256 = x
        # subtract channel means if necessary
        if K.dtype(x) == "float32":
            # input is not a 'uint8' image
            # assume it has already been normalized
            xout = x256
        else:
            # input is a 'uint8' image
            # subtract channel-wise means
            xout = preprocess_input(x256)
        return xout

    def compute_output_shape(self, input_shape):
        return (input_shape[0], 256, 256, 3)


def create_cmfd_similarity_branch(
    img_shape=(256, 256, 3), nb_pools=100, name="simiDet"
):
    """Create the similarity branch for copy-move forgery detection"""
    # ---------------------------------------------------------
    # Input
    # ---------------------------------------------------------
    img_input = Input(shape=img_shape, name=name + "_in")
    # ---------------------------------------------------------
    # VGG16 Conv Featex
    # ---------------------------------------------------------
    bname = name + "_cnn"
    ## Block 1
    x1 = Conv2D(64, (3, 3), activation="relu", padding="same", name=bname + "_b1c1")(
        img_input
    )
    x1 = Conv2D(64, (3, 3), activation="relu", padding="same", name=bname + "_b1c2")(x1)
    x1 = MaxPooling2D((2, 2), strides=(2, 2), name=bname + "_b1p")(x1)
    # Block 2
    x2 = Conv2D(128, (3, 3), activation="relu", padding="same", name=bname + "_b2c1")(
        x1
    )
    x2 = Conv2D(128, (3, 3), activation="relu", padding="same", name=bname + "_b2c2")(
        x2
    )
    x2 = MaxPooling2D((2, 2), strides=(2, 2), name=bname + "_b2p")(x2)
    # Block 3
    x3 = Conv2D(256, (3, 3), activation="relu", padding="same", name=bname + "_b3c1")(
        x2
    )
    x3 = Conv2D(256, (3, 3), activation="relu", padding="same", name=bname + "_b3c2")(
        x3
    )
    x3 = Conv2D(256, (3, 3), activation="relu", padding="same", name=bname + "_b3c3")(
        x3
    )
    x3 = MaxPooling2D((2, 2), strides=(2, 2), name=bname + "_b3p")(x3)
    # Block 4
    x4 = Conv2D(512, (3, 3), activation="relu", padding="same", name=bname + "_b4c1")(
        x3
    )
    x4 = Conv2D(512, (3, 3), activation="relu", padding="same", name=bname + "_b4c2")(
        x4
    )
    x4 = Conv2D(512, (3, 3), activation="relu", padding="same", name=bname + "_b4c3")(
        x4
    )
    x4 = MaxPooling2D((2, 2), strides=(2, 2), name=bname + "_b4p")(x4)
    # Local Std-Norm Normalization (within each sample)
    xx = Activation(std_norm_along_chs, name=bname + "_sn")(x4)
    # ---------------------------------------------------------
    # Self Correlation Pooling
    # ---------------------------------------------------------
    bname = name + "_corr"
    ## Self Correlation
    xcorr = SelfCorrelationPercPooling(name=bname + "_corr")(xx)
    ## Global Batch Normalization (across samples)
    xn = BatchNormalization(name=bname + "_bn")(xcorr)
    # ---------------------------------------------------------
    # Deconvolution Network
    # ---------------------------------------------------------
    patch_list = [(1, 1), (3, 3), (5, 5)]
    # MultiPatch Featex
    bname = name + "_dconv"
    f16 = BnInception(xn, 8, patch_list, name=bname + "_mpf")
    # Deconv x2
    f32 = BilinearUpSampling2D(name=bname + "_bx2")(f16)
    dx32 = BnInception(f32, 6, patch_list, name=bname + "_dx2")
    # Deconv x4
    f64a = BilinearUpSampling2D(name=bname + "_bx4a")(f32)
    f64b = BilinearUpSampling2D(name=bname + "_bx4b")(dx32)
    f64 = Concatenate(axis=-1, name=name + "_dx4_m")([f64a, f64b])
    dx64 = BnInception(f64, 4, patch_list, name=bname + "_dx4")
    # Deconv x8
    f128a = BilinearUpSampling2D(name=bname + "_bx8a")(f64a)
    f128b = BilinearUpSampling2D(name=bname + "_bx8b")(dx64)
    f128 = Concatenate(axis=-1, name=name + "_dx8_m")([f128a, f128b])
    dx128 = BnInception(f128, 2, patch_list, name=bname + "_dx8")
    # Deconv x16
    f256a = BilinearUpSampling2D(name=bname + "_bx16a")(f128a)
    f256b = BilinearUpSampling2D(name=bname + "_bx16b")(dx128)
    f256 = Concatenate(axis=-1, name=name + "_dx16_m")([f256a, f256b])
    dx256 = BnInception(f256, 2, patch_list, name=bname + "_dx16")
    # Summarize
    fm256 = Concatenate(axis=-1, name=name + "_mfeat")([f256a, dx256])
    masks = BnInception(fm256, 2, [(5, 5), (7, 7), (11, 11)], name=bname + "_dxF")
    # ---------------------------------------------------------
    # Output for Auxiliary Task
    # ---------------------------------------------------------
    pred_mask = Conv2D(
        1, (3, 3), activation="sigmoid", name=name + "_pred_mask", padding="same"
    )(masks)
    # ---------------------------------------------------------
    # End to End
    # ---------------------------------------------------------
    model = Model(inputs=img_input, outputs=pred_mask, name=name)
    return model


def create_cmfd_manipulation_branch(img_shape=(256, 256, 3), name="maniDet"):
    """Create the manipulation branch for copy-move forgery detection"""
    # ---------------------------------------------------------
    # Input
    # ---------------------------------------------------------
    img_input = Input(shape=img_shape, name=name + "_in")
    # ---------------------------------------------------------
    # VGG16 Conv Featex
    # ---------------------------------------------------------
    bname = name + "_cnn"
    # Block 1
    x1 = Conv2D(64, (3, 3), activation="relu", padding="same", name=bname + "_b1c1")(
        img_input
    )
    x1 = Conv2D(64, (3, 3), activation="relu", padding="same", name=bname + "_b1c2")(x1)
    x1 = MaxPooling2D((2, 2), strides=(2, 2), name=bname + "_b1p")(x1)
    # Block 2
    x2 = Conv2D(128, (3, 3), activation="relu", padding="same", name=bname + "_b2c1")(
        x1
    )
    x2 = Conv2D(128, (3, 3), activation="relu", padding="same", name=bname + "_b2c2")(
        x2
    )
    x2 = MaxPooling2D((2, 2), strides=(2, 2), name=bname + "_b2p")(x2)
    # Block 3
    x3 = Conv2D(256, (3, 3), activation="relu", padding="same", name=bname + "_b3c1")(
        x2
    )
    x3 = Conv2D(256, (3, 3), activation="relu", padding="same", name=bname + "_b3c2")(
        x3
    )
    x3 = Conv2D(256, (3, 3), activation="relu", padding="same", name=bname + "_b3c3")(
        x3
    )
    x3 = MaxPooling2D((2, 2), strides=(2, 2), name=bname + "_b3p")(x3)
    # Block 4
    x4 = Conv2D(512, (3, 3), activation="relu", padding="same", name=bname + "_b4c1")(
        x3
    )
    x4 = Conv2D(512, (3, 3), activation="relu", padding="same", name=bname + "_b4c2")(
        x4
    )
    x4 = Conv2D(512, (3, 3), activation="relu", padding="same", name=bname + "_b4c3")(
        x4
    )
    x4 = MaxPooling2D((2, 2), strides=(2, 2), name=bname + "_b4p")(x4)
    # ---------------------------------------------------------
    # Deconvolution Network
    # ---------------------------------------------------------
    patch_list = [(1, 1), (3, 3), (5, 5)]
    bname = name + "_dconv"
    # MultiPatch Featex
    f16 = BnInception(x4, 8, patch_list, name=bname + "_mpf")
    # Deconv x2
    f32 = BilinearUpSampling2D(name=bname + "_bx2")(f16)
    dx32 = BnInception(f32, 6, patch_list, name=bname + "_dx2")
    # Deconv x4
    f64 = BilinearUpSampling2D(name=bname + "_bx4")(dx32)
    dx64 = BnInception(f64, 4, patch_list, name=bname + "_dx4")
    # Deconv x8
    f128 = BilinearUpSampling2D(name=bname + "_bx8")(dx64)
    dx128 = BnInception(f128, 2, patch_list, name=bname + "_dx8")
    # Deconv x16
    f256 = BilinearUpSampling2D(name=bname + "_bx16")(dx128)
    dx256 = BnInception(f256, 2, [(5, 5), (7, 7), (11, 11)], name=bname + "_dx16")
    # ---------------------------------------------------------
    # Output for Auxiliary Task
    # ---------------------------------------------------------
    pred_mask = Conv2D(
        1, (3, 3), activation="sigmoid", name=bname + "_pred_mask", padding="same"
    )(dx256)
    # ---------------------------------------------------------
    # End to End
    # ---------------------------------------------------------
    model = Model(inputs=img_input, outputs=pred_mask, name=bname)
    return model


def create_BusterNet_testing_model(weight_file=None):
    """Create a BusterNet testing model with pretrained weights"""
    # 1. create branch models
    simi_branch = create_cmfd_similarity_branch()
    mani_branch = create_cmfd_manipulation_branch()
    # 2. crop off the last auxiliary task layer
    SimiDet = Model(
        inputs=simi_branch.inputs,
        outputs=simi_branch.layers[-2].output,
        name="simiFeatex",
    )
    ManiDet = Model(
        inputs=mani_branch.inputs,
        outputs=mani_branch.layers[-2].output,
        name="maniFeatex",
    )
    # 3. define the two-branch BusterNet model
    # 3.a define wrapper inputs
    img_raw = Input(shape=(None, None, 3), name="image_in")
    img_in = Preprocess(name="preprocess")(img_raw)
    # 3.b define BusterNet Core
    simi_feat = SimiDet(img_in)
    mani_feat = ManiDet(img_in)
    merged_feat = Concatenate(axis=-1, name="merge")([simi_feat, mani_feat])
    f = BnInception(merged_feat, 3, name="fusion")
    mask_out = Conv2D(
        3, (3, 3), padding="same", activation="softmax", name="pred_mask"
    )(f)
    # 3.c define wrapper output
    mask_out = ResizeBack(name="restore")([mask_out, img_raw])
    # 4. create the BusterNet model end-to-end
    model = Model(inputs=img_raw, outputs=mask_out, name="busterNet")
    if weight_file is not None:
        try:
            model.load_weights(weight_file)
            print(
                "INFO: successfully loaded pretrained weights from {}".format(
                    weight_file
                )
            )
        except Exception as e:
            print(
                "INFO: failed to load pretrained weights from {} for reason: {}".format(
                    weight_file, e
                )
            )
    return model
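A minimal usage sketch (editorial addition, not part of the commit), assuming this repo layout and that the LFS weight file has been pulled locally; it builds the end-to-end testing model and predicts a 3-channel softmax mask for one RGB image of arbitrary size:

    import numpy as np
    from BusterNet.BusterNetCore import create_BusterNet_testing_model

    model = create_BusterNet_testing_model("BusterNet/pretrained_busterNet.hd5")
    rgb = np.random.randint(0, 256, (480, 640, 3), dtype=np.uint8)  # stand-in image
    pred = model.predict(np.expand_dims(rgb, axis=0))[0]            # (480, 640, 3) mask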
BusterNet/BusterNetUtils.py
ADDED
@@ -0,0 +1,27 @@
import numpy as np
import matplotlib.pyplot as plt


def simple_cmfd_decoder(busterNetModel, rgb):
    """A simple BusterNet CMFD decoder"""
    # 1. expand an image to a single sample batch
    single_sample_batch = np.expand_dims(rgb, axis=0)
    # 2. perform busterNet CMFD
    pred = busterNetModel.predict(single_sample_batch)[0]
    return pred


def visualize_result(rgb, gt, pred, figsize=(12, 4), title=None):
    """Visualize raw input, ground truth, and BusterNet result"""
    fig = plt.figure(figsize=figsize)

    plt.subplot(1, 3, 1)
    plt.imshow(rgb)
    plt.title("input image")
    plt.subplot(1, 3, 2)
    plt.title("ground truth")
    plt.imshow(gt)
    plt.subplot(1, 3, 3)
    plt.imshow(pred)
    plt.title("busterNet pred")
    return fig
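A follow-up sketch (editorial addition, not part of the commit) chaining the two helpers above; the ground-truth mask here is a zero placeholder:

    import numpy as np
    from BusterNet.BusterNetCore import create_BusterNet_testing_model
    from BusterNet.BusterNetUtils import simple_cmfd_decoder, visualize_result

    model = create_BusterNet_testing_model("BusterNet/pretrained_busterNet.hd5")
    rgb = np.random.randint(0, 256, (256, 256, 3), dtype=np.uint8)  # stand-in image
    gt = np.zeros((256, 256, 3), dtype=np.uint8)                    # placeholder mask
    fig = visualize_result(rgb, gt, simple_cmfd_decoder(model, rgb))
    fig.savefig("busternet_demo.png")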
BusterNet/README.md
ADDED
@@ -0,0 +1,74 @@
# BusterNet: Detecting Copy-Move Image Forgery with Source/Target Localization

### Introduction
We introduce a novel deep neural architecture for image copy-move forgery detection (CMFD), code-named *BusterNet*. Unlike previous efforts, BusterNet is a pure, end-to-end trainable, deep neural network solution. It features a two-branch architecture followed by a fusion module. The two branches localize potential manipulation regions via visual artifacts and copy-move regions via visual similarities, respectively. To the best of our knowledge, this is the first CMFD algorithm with the discernibility to localize source/target regions.

In this repository, we release many paper-related materials, including

- a pretrained BusterNet model
- custom layers implemented in keras-tensorflow
- the CASIA-CMFD, CoMoFoD-CMFD, and USCISI-CMFD datasets
- a python notebook to reproduce the paper results

### Repo Organization
The entire repo is organized as follows:

- **Data** - hosts all datasets
  - *CASIA-CMFD
  - *CoMoFoD-CMFD
  - *USCISI-CMFD
- **Model** - hosts all model files
- **ReadMe.md** - this file

Due to the size limit, we can't host all datasets in the repo; the large ones are hosted externally. Datasets marked with * must be downloaded separately. Please refer to each dataset's documentation for detailed download instructions.

### Python/Keras/Tensorflow
The original model was trained with

- keras.version = 2.0.7
- tensorflow.version = 1.1.0

We also tested the repository with

- keras.version = 2.2.2
- tensorflow.version = 1.8.0

Though small differences may be found, the results are in general consistent.
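As a quick sanity check before loading the pretrained weights, a snippet along these lines (ours, not from the original release) prints the installed versions:

    import keras
    import tensorflow
    print("keras", keras.__version__, "| tensorflow", tensorflow.__version__)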

### Citation
If you use the provided code or data in any publication, please kindly cite the following paper.

    @inproceedings{wu2018eccv,
      title={BusterNet: Detecting Image Copy-Move Forgery With Source/Target Localization},
      author={Wu, Yue and AbdAlmageed, Wael and Natarajan, Prem},
      booktitle={European Conference on Computer Vision (ECCV)},
      year={2018},
      organization={Springer},
    }

### Contact
- Name: Yue Wu
- Email: yue_wu\[at\]isi.edu

### License
The Software is made available for academic or non-commercial purposes only. The license is for a copy of the program for an unlimited term. Individuals requesting a license for commercial use must pay for a commercial license.

USC Stevens Institute for Innovation
University of Southern California
1150 S. Olive Street, Suite 2300
Los Angeles, CA 90115, USA
ATTN: Accounting

DISCLAIMER. USC MAKES NO EXPRESS OR IMPLIED WARRANTIES, EITHER IN FACT OR BY OPERATION OF LAW, BY STATUTE OR OTHERWISE, AND USC SPECIFICALLY AND EXPRESSLY DISCLAIMS ANY EXPRESS OR IMPLIED WARRANTY OF MERCHANTABILITY OR FITNESS FOR A PARTICULAR PURPOSE, VALIDITY OF THE SOFTWARE OR ANY OTHER INTELLECTUAL PROPERTY RIGHTS OR NON-INFRINGEMENT OF THE INTELLECTUAL PROPERTY OR OTHER RIGHTS OF ANY THIRD PARTY. SOFTWARE IS MADE AVAILABLE AS-IS. LIMITATION OF LIABILITY. TO THE MAXIMUM EXTENT PERMITTED BY LAW, IN NO EVENT WILL USC BE LIABLE TO ANY USER OF THIS CODE FOR ANY INCIDENTAL, CONSEQUENTIAL, EXEMPLARY OR PUNITIVE DAMAGES OF ANY KIND, LOST GOODWILL, LOST PROFITS, LOST BUSINESS AND/OR ANY INDIRECT ECONOMIC DAMAGES WHATSOEVER, REGARDLESS OF WHETHER SUCH DAMAGES ARISE FROM CLAIMS BASED UPON CONTRACT, NEGLIGENCE, TORT (INCLUDING STRICT LIABILITY OR OTHER LEGAL THEORY), A BREACH OF ANY WARRANTY OR TERM OF THIS AGREEMENT, AND REGARDLESS OF WHETHER USC WAS ADVISED OR HAD REASON TO KNOW OF THE POSSIBILITY OF INCURRING SUCH DAMAGES IN ADVANCE.

For commercial license pricing and annual commercial update and support pricing, please contact:

Rakesh Pandit
USC Stevens Institute for Innovation
University of Southern California
1150 S. Olive Street, Suite 2300
Los Angeles, CA 90115, USA

Tel: +1 213-821-3552
Fax: +1 213-821-5001
Email: rakeshvp@usc.edu and cc to: accounting@stevens.usc.edu
BusterNet/__pycache__/BusterNetCore.cpython-37.pyc
ADDED
Binary file (10.8 kB).
BusterNet/__pycache__/BusterNetUtils.cpython-37.pyc
ADDED
Binary file (940 Bytes).
BusterNet/pretrained_busterNet.hd5
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:bd6b4628cf302e12bbe7092f5c1ed3f145d13b345930a588b5326d100207468f
size 62258088
MantraNet/MantraNetv4.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:69835136cec7f87d89820cdad8a21d9c4f1700ee8074c345cdc2289c1e1099e4
size 15234840
MantraNet/__init__.py
ADDED
@@ -0,0 +1 @@
from MantraNet.mantranet import pre_trained_model, check_forgery
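A usage sketch of the exported API (editorial addition, not part of the commit), run from the repo root; "example.jpg" is a placeholder path:

    from MantraNet import pre_trained_model, check_forgery

    model = pre_trained_model(weight_path="MantraNet/MantraNetv4.pt")
    check_forgery(model, img_path="example.jpg")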
MantraNet/__pycache__/__init__.cpython-37.pyc
ADDED
Binary file (248 Bytes).
MantraNet/__pycache__/mantranet.cpython-37.pyc
ADDED
Binary file (20.9 kB).
MantraNet/mantranet.py
ADDED
@@ -0,0 +1,946 @@
import os
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from collections import OrderedDict

# Pytorch
import torch
from torch import nn
import torch.nn.functional as F

# pytorch-lightning
import pytorch_lightning as pl

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


## Reproduction of the hard sigmoid used by tensorflow (which is not exactly
## the same as the one in Pytorch)
def hardsigmoid(T):
    T_0 = T
    T = 0.2 * T_0 + 0.5
    T[T_0 < -2.5] = 0
    T[T_0 > 2.5] = 1

    return T
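# Illustrative comparison (editorial addition, not part of this commit): the
# TF-style hard sigmoid above clips 0.2*x + 0.5, while PyTorch's built-in
# nn.Hardsigmoid clips x/6 + 0.5, so the two curves differ away from zero.
def _demo_hardsigmoid():
    t = torch.linspace(-3.0, 3.0, 7)
    print(hardsigmoid(t.clone()))     # TF-style reproduction defined above
    print(torch.nn.Hardsigmoid()(t))  # PyTorch built-in, for comparison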
## ConvLSTM - equivalent implementation of ConvLSTM2d in pytorch
## Source: https://github.com/ndrplz/ConvLSTM_pytorch
class ConvLSTMCell(nn.Module):
    def __init__(self, input_dim, hidden_dim, kernel_size, bias):
        """
        Initialize ConvLSTM cell.
        Parameters
        ----------
        input_dim: int
            Number of channels of input tensor.
        hidden_dim: int
            Number of channels of hidden state.
        kernel_size: (int, int)
            Size of the convolutional kernel.
        bias: bool
            Whether or not to add the bias.
        """

        super(ConvLSTMCell, self).__init__()

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim

        self.kernel_size = kernel_size
        self.padding = kernel_size[0] // 2, kernel_size[1] // 2
        self.bias = bias

        self.conv = nn.Conv2d(
            in_channels=self.input_dim + self.hidden_dim,
            out_channels=4 * self.hidden_dim,
            kernel_size=self.kernel_size,
            padding=self.padding,
            bias=self.bias,
        )

        self.sigmoid = hardsigmoid

    def forward(self, input_tensor, cur_state):
        h_cur, c_cur = cur_state

        combined = torch.cat(
            [input_tensor, h_cur], dim=1
        )  # concatenate along channel axis

        combined_conv = self.conv(combined)
        cc_i, cc_f, cc_c, cc_o = torch.split(combined_conv, self.hidden_dim, dim=1)
        i = self.sigmoid(cc_i)
        f = self.sigmoid(cc_f)
        c_next = f * c_cur + i * torch.tanh(cc_c)
        o = self.sigmoid(cc_o)

        h_next = o * torch.tanh(c_next)

        return h_next, c_next

    def init_hidden(self, batch_size, image_size):
        height, width = image_size
        return (
            torch.zeros(
                batch_size,
                self.hidden_dim,
                height,
                width,
                device=self.conv.weight.device,
            ),
            torch.zeros(
                batch_size,
                self.hidden_dim,
                height,
                width,
                device=self.conv.weight.device,
            ),
        )


class ConvLSTM(nn.Module):
    """

    Parameters:
        input_dim: Number of channels in input
        hidden_dim: Number of hidden channels
        kernel_size: Size of kernel in convolutions
        num_layers: Number of LSTM layers stacked on each other
        batch_first: Whether or not dimension 0 is the batch or not
        bias: Bias or no bias in Convolution
        return_all_layers: Return the list of computations for all layers
    Note: Will do same padding.

    Input:
        A tensor of size B, T, C, H, W or T, B, C, H, W
    Output:
        A tuple of two lists of length num_layers (or length 1 if return_all_layers is False).
        0 - layer_output_list is the list of lists of length T of each output
        1 - last_state_list is the list of last states
            each element of the list is a tuple (h, c) for hidden state and memory
    Example:
        >> x = torch.rand((32, 10, 64, 128, 128))
        >> convlstm = ConvLSTM(64, 16, (3, 3), 1, True, True, False)
        >> _, last_states = convlstm(x)
        >> h = last_states[0][0]  # 0 for layer index, 0 for h index
    """

    def __init__(
        self,
        input_dim,
        hidden_dim,
        kernel_size,
        num_layers,
        batch_first=False,
        bias=True,
        return_all_layers=False,
    ):
        super(ConvLSTM, self).__init__()

        self._check_kernel_size_consistency(kernel_size)

        # Make sure that both `kernel_size` and `hidden_dim` are lists having len == num_layers
        kernel_size = self._extend_for_multilayer(kernel_size, num_layers)
        hidden_dim = self._extend_for_multilayer(hidden_dim, num_layers)
        if not len(kernel_size) == len(hidden_dim) == num_layers:
            raise ValueError("Inconsistent list length.")

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.kernel_size = kernel_size
        self.num_layers = num_layers
        self.batch_first = batch_first
        self.bias = bias
        self.return_all_layers = return_all_layers

        cell_list = []
        for i in range(0, self.num_layers):
            cur_input_dim = self.input_dim if i == 0 else self.hidden_dim[i - 1]

            cell_list.append(
                ConvLSTMCell(
                    input_dim=cur_input_dim,
                    hidden_dim=self.hidden_dim[i],
                    kernel_size=self.kernel_size[i],
                    bias=self.bias,
                )
            )

        self.cell_list = nn.ModuleList(cell_list)

    def forward(self, input_tensor, hidden_state=None):
        """

        Parameters
        ----------
        input_tensor:
            5-D Tensor either of shape (t, b, c, h, w) or (b, t, c, h, w)
        hidden_state:
            None. todo implement stateful

        Returns
        -------
        last_state_list, layer_output
        """
        if not self.batch_first:
            # (t, b, c, h, w) -> (b, t, c, h, w)
            input_tensor = input_tensor.transpose(0, 1)

        b, _, _, h, w = input_tensor.size()

        # Implement stateful ConvLSTM
        if hidden_state is not None:
            raise NotImplementedError()
        else:
            # Since the init is done in forward, we can send the image size here
            hidden_state = self._init_hidden(batch_size=b, image_size=(h, w))

        layer_output_list = []
        last_state_list = []

        seq_len = input_tensor.size(1)
        cur_layer_input = input_tensor

        for layer_idx in range(self.num_layers):

            h, c = hidden_state[layer_idx]
            output_inner = []
            for t in range(seq_len):
                h, c = self.cell_list[layer_idx](
                    input_tensor=cur_layer_input[:, t, :, :, :], cur_state=[h, c]
                )
                output_inner.append(h)

            layer_output = torch.stack(output_inner, dim=1)
            cur_layer_input = layer_output

            layer_output_list.append(layer_output)
            last_state_list.append([h, c])

        if not self.return_all_layers:
            layer_output_list = layer_output_list[-1:]
            last_state_list = last_state_list[-1:]

        return layer_output_list, last_state_list

    def _init_hidden(self, batch_size, image_size):
        init_states = []
        for i in range(self.num_layers):
            init_states.append(self.cell_list[i].init_hidden(batch_size, image_size))
        return init_states

    @staticmethod
    def _check_kernel_size_consistency(kernel_size):
        if not (
            isinstance(kernel_size, tuple)
            or (
                isinstance(kernel_size, list)
                and all([isinstance(elem, tuple) for elem in kernel_size])
            )
        ):
            raise ValueError("`kernel_size` must be tuple or list of tuples")

    @staticmethod
    def _extend_for_multilayer(param, num_layers):
        if not isinstance(param, list):
            param = [param] * num_layers
        return param
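# Illustrative shape check (editorial addition, not part of this commit):
# a one-layer ConvLSTM over a 4-step sequence of 64-channel 16x16 maps,
# mirroring how MantraNet feeds its Z-pool stack further below.
def _demo_convlstm():
    seq = torch.rand(4, 2, 64, 16, 16)                  # (t, b, c, h, w)
    _, last_states = ConvLSTM(64, 8, [(7, 7)], 1)(seq)  # batch_first=False
    h, c = last_states[0]                               # final hidden/cell states
    assert h.shape == (2, 8, 16, 16)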
class ConvGruCell(nn.Module):
    def __init__(self, input_dim, hidden_dim, kernel_size, bias):
        """
        Initialize ConvGRU cell.
        Parameters
        ----------
        input_dim: int
            Number of channels of input tensor.
        hidden_dim: int
            Number of channels of hidden state.
        kernel_size: (int, int)
            Size of the convolutional kernel.
        bias: bool
            Whether or not to add the bias.
        """

        super(ConvGruCell, self).__init__()

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim

        self.kernel_size = kernel_size
        self.padding = kernel_size[0] // 2, kernel_size[1] // 2
        self.bias = bias

        self.sigmoid = hardsigmoid

        self.conv1 = nn.Conv2d(
            in_channels=self.input_dim + self.hidden_dim,
            out_channels=2 * self.hidden_dim,
            kernel_size=self.kernel_size,
            padding=self.padding,
            bias=self.bias,
        )

        self.conv2 = nn.Conv2d(
            in_channels=self.input_dim + self.hidden_dim,
            out_channels=self.hidden_dim,
            kernel_size=self.kernel_size,
            padding=self.padding,
            bias=self.bias,
        )

    def forward(self, input_tensor, cur_state):
        h_cur = cur_state

        h_x = torch.cat([h_cur, input_tensor], dim=1)  # concatenate along channel axis

        combined_conv = self.conv1(h_x)
        cc_r, cc_u = torch.split(combined_conv, self.hidden_dim, dim=1)
        r = self.sigmoid(cc_r)
        u = self.sigmoid(cc_u)

        x_r_o_h = torch.cat([input_tensor, r * h_cur], dim=1)
        combined_conv = self.conv2(x_r_o_h)

        c = nn.Tanh()(combined_conv)
        h_next = (1 - u) * h_cur + u * c

        return h_next

    def init_hidden(self, batch_size, image_size):
        height, width = image_size
        return torch.zeros(
            batch_size, self.hidden_dim, height, width, device=self.conv1.weight.device
        )


class ConvGRU(nn.Module):
    """

    Parameters:
        input_dim: Number of channels in input
        hidden_dim: Number of hidden channels
        kernel_size: Size of kernel in convolutions
        num_layers: Number of GRU layers stacked on each other
        batch_first: Whether or not dimension 0 is the batch or not
        bias: Bias or no bias in Convolution
        return_all_layers: Return the list of computations for all layers
    Note: Will do same padding.

    Input:
        A tensor of size B, T, C, H, W or T, B, C, H, W
    Output:
        A tuple of two lists of length num_layers (or length 1 if return_all_layers is False).
        0 - layer_output_list is the list of lists of length T of each output
        1 - last_state_list is the list of last states
            each element of the list is the final hidden state h
    Example:
        >> x = torch.rand((32, 10, 64, 128, 128))
        >> convgru = ConvGRU(64, 16, (3, 3), 1, True, True, False)
        >> _, last_states = convgru(x)
        >> h = last_states[0]  # 0 for layer index
    """

    def __init__(
        self,
        input_dim,
        hidden_dim,
        kernel_size,
        num_layers,
        batch_first=False,
        bias=True,
        return_all_layers=False,
    ):
        super(ConvGRU, self).__init__()

        self._check_kernel_size_consistency(kernel_size)

        # Make sure that both `kernel_size` and `hidden_dim` are lists having len == num_layers
        kernel_size = self._extend_for_multilayer(kernel_size, num_layers)
        hidden_dim = self._extend_for_multilayer(hidden_dim, num_layers)
        if not len(kernel_size) == len(hidden_dim) == num_layers:
            raise ValueError("Inconsistent list length.")

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.kernel_size = kernel_size
        self.num_layers = num_layers
        self.batch_first = batch_first
        self.bias = bias
        self.return_all_layers = return_all_layers

        cell_list = []
        for i in range(0, self.num_layers):
            cur_input_dim = self.input_dim if i == 0 else self.hidden_dim[i - 1]

            cell_list.append(
                ConvGruCell(
                    input_dim=cur_input_dim,
                    hidden_dim=self.hidden_dim[i],
                    kernel_size=self.kernel_size[i],
                    bias=self.bias,
                )
            )

        self.cell_list = nn.ModuleList(cell_list)

    def forward(self, input_tensor, hidden_state=None):
        """

        Parameters
        ----------
        input_tensor:
            5-D Tensor either of shape (t, b, c, h, w) or (b, t, c, h, w)
        hidden_state:
            None. todo implement stateful

        Returns
        -------
        last_state_list, layer_output
        """
        if not self.batch_first:
            # (t, b, c, h, w) -> (b, t, c, h, w)
            input_tensor = input_tensor.transpose(0, 1)

        b, _, _, h, w = input_tensor.size()

        # Implement stateful ConvGRU
        if hidden_state is not None:
            raise NotImplementedError()
        else:
            # Since the init is done in forward, we can send the image size here
            hidden_state = self._init_hidden(batch_size=b, image_size=(h, w))

        layer_output_list = []
        last_state_list = []

        seq_len = input_tensor.size(1)
        cur_layer_input = input_tensor

        for layer_idx in range(self.num_layers):

            h = hidden_state[layer_idx]
            output_inner = []
            for t in range(seq_len):
                h = self.cell_list[layer_idx](
                    input_tensor=cur_layer_input[:, t, :, :, :], cur_state=h
                )
                output_inner.append(h)

            layer_output = torch.stack(output_inner, dim=1)
            cur_layer_input = layer_output

            layer_output_list.append(layer_output)
            last_state_list.append(h)

        if not self.return_all_layers:
            layer_output_list = layer_output_list[-1:]
            last_state_list = last_state_list[-1:]

        return layer_output_list, last_state_list

    def _init_hidden(self, batch_size, image_size):
        init_states = []
        for i in range(self.num_layers):
            init_states.append(self.cell_list[i].init_hidden(batch_size, image_size))
        return init_states

    @staticmethod
    def _check_kernel_size_consistency(kernel_size):
        if not (
            isinstance(kernel_size, tuple)
            or (
                isinstance(kernel_size, list)
                and all([isinstance(elem, tuple) for elem in kernel_size])
            )
        ):
            raise ValueError("`kernel_size` must be tuple or list of tuples")

    @staticmethod
    def _extend_for_multilayer(param, num_layers):
        if not isinstance(param, list):
            param = [param] * num_layers
        return param


## Symmetric padding (not existing natively in Pytorch)
## Source: https://discuss.pytorch.org/t/symmetric-padding/19866/3


def reflect(x, minx, maxx):
    """Reflects an array around two points making a triangular waveform that ramps up
    and down, allowing for pad lengths greater than the input length"""
    rng = maxx - minx
    double_rng = 2 * rng
    mod = np.fmod(x - minx, double_rng)
    normed_mod = np.where(mod < 0, mod + double_rng, mod)
    out = np.where(normed_mod >= rng, double_rng - normed_mod, normed_mod) + minx
    return np.array(out, dtype=x.dtype)


def symm_pad(im, padding):
    h, w = im.shape[-2:]
    left, right, top, bottom = padding

    x_idx = np.arange(-left, w + right)
    y_idx = np.arange(-top, h + bottom)

    x_pad = reflect(x_idx, -0.5, w - 0.5)
    y_pad = reflect(y_idx, -0.5, h - 0.5)
    xx, yy = np.meshgrid(x_pad, y_pad)
    return im[..., yy, xx]
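# Illustrative check (editorial addition, not part of this commit): symm_pad
# should agree with numpy's 'symmetric' padding mode on the two spatial dims.
def _demo_symm_pad():
    t = torch.arange(16.0).view(1, 1, 4, 4)
    ref = np.pad(t.numpy(), ((0, 0), (0, 0), (1, 1), (2, 2)), mode="symmetric")
    assert np.allclose(symm_pad(t, (2, 2, 1, 1)).numpy(), ref)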
# batch normalization equivalent to the one proposed in tensorflow
# Source: https://gluon.mxnet.io/chapter04_convolutional-neural-networks/cnn-batch-norm-scratch.html


def batch_norm(X, eps=0.001):
    # extract the dimensions
    N, C, H, W = X.shape
    device = X.device
    # mini-batch mean
    mean = X.mean(axis=(0, 2, 3)).to(device)
    # mini-batch variance
    variance = ((X - mean.view((1, C, 1, 1))) ** 2).mean(axis=(0, 2, 3)).to(device)
    # normalize
    X = (
        (X - mean.reshape((1, C, 1, 1)))
        * 1.0
        / torch.pow((variance.view((1, C, 1, 1)) + eps), 0.5)
    )
    return X.to(device)
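# Illustrative check (editorial addition, not part of this commit): the
# from-scratch batch_norm above should match F.batch_norm evaluated with
# batch statistics and no affine parameters.
def _demo_batch_norm():
    x = torch.rand(2, 3, 4, 4)
    ref = F.batch_norm(x, None, None, training=True, eps=0.001)
    assert torch.allclose(batch_norm(x), ref, atol=1e-5)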
# MantraNet (equivalent to the one coded in tensorflow at https://github.com/ISICV/ManTraNet)
class MantraNet(nn.Module):
    def __init__(self, in_channel=3, eps=10 ** (-6), device=device):
        super(MantraNet, self).__init__()

        self.eps = eps
        self.relu = nn.ReLU()
        self.device = device

        # ********** IMAGE MANIPULATION TRACE FEATURE EXTRACTOR *********

        ## Initialisation

        self.init_conv = nn.Conv2d(in_channel, 4, 5, 1, padding=0, bias=False)

        self.BayarConv2D = nn.Conv2d(in_channel, 3, 5, 1, padding=0, bias=False)
        # float32 masks so that the in-place Bayar weight updates in forward() type-check
        self.bayar_mask = torch.tensor(np.ones(shape=(5, 5)), dtype=torch.float32).to(
            self.device
        )
        self.bayar_mask[2, 2] = 0

        self.bayar_final = torch.tensor(np.zeros((5, 5)), dtype=torch.float32).to(
            self.device
        )
        self.bayar_final[2, 2] = -1

        self.SRMConv2D = nn.Conv2d(in_channel, 9, 5, 1, padding=0, bias=False)
        self.SRMConv2D.weight.data = torch.load("MantraNet/MantraNetv4.pt")[
            "SRMConv2D.weight"
        ]

        ## SRM filters (fixed)
        for param in self.SRMConv2D.parameters():
            param.requires_grad = False

        self.middle_and_last_block = nn.ModuleList(
            [
                nn.Conv2d(16, 32, 3, 1, padding=0),
                nn.ReLU(),
                nn.Conv2d(32, 64, 3, 1, padding=0),
                nn.ReLU(),
                nn.Conv2d(64, 64, 3, 1, padding=0),
                nn.ReLU(),
                nn.Conv2d(64, 128, 3, 1, padding=0),
                nn.ReLU(),
                nn.Conv2d(128, 128, 3, 1, padding=0),
                nn.ReLU(),
                nn.Conv2d(128, 128, 3, 1, padding=0),
                nn.ReLU(),
                nn.Conv2d(128, 256, 3, 1, padding=0),
                nn.ReLU(),
                nn.Conv2d(256, 256, 3, 1, padding=0),
                nn.ReLU(),
                nn.Conv2d(256, 256, 3, 1, padding=0),
                nn.ReLU(),
                nn.Conv2d(256, 256, 3, 1, padding=0),
                nn.ReLU(),
                nn.Conv2d(256, 256, 3, 1, padding=0),
                nn.ReLU(),
                nn.Conv2d(256, 256, 3, 1, padding=0),
            ]
        )

        # ********** LOCAL ANOMALY DETECTOR *********

        self.adaptation = nn.Conv2d(256, 64, 1, 1, padding=0, bias=False)

        self.sigma_F = nn.Parameter(torch.zeros((1, 64, 1, 1)), requires_grad=True)

        self.pool31 = nn.AvgPool2d(31, stride=1, padding=15, count_include_pad=False)
        self.pool15 = nn.AvgPool2d(15, stride=1, padding=7, count_include_pad=False)
        self.pool7 = nn.AvgPool2d(7, stride=1, padding=3, count_include_pad=False)

        self.convlstm = ConvLSTM(
            input_dim=64,
            hidden_dim=8,
            kernel_size=(7, 7),
            num_layers=1,
            batch_first=False,
            bias=True,
            return_all_layers=False,
        )

        self.end = nn.Sequential(nn.Conv2d(8, 1, 7, 1, padding=3), nn.Sigmoid())

    def forward(self, x):
        B, nb_channel, H, W = x.shape

        if not (self.training):
            self.GlobalPool = nn.AvgPool2d((H, W), stride=1)
        else:
            if not hasattr(self, "GlobalPool"):
                self.GlobalPool = nn.AvgPool2d((H, W), stride=1)

        # Normalization
        x = x / 255.0 * 2 - 1

        ## Image Manipulation Trace Feature Extractor

        ## **Bayar constraints**

        self.BayarConv2D.weight.data *= self.bayar_mask
        self.BayarConv2D.weight.data *= torch.pow(
            self.BayarConv2D.weight.data.sum(axis=(2, 3)).view(3, 3, 1, 1), -1
        )
        self.BayarConv2D.weight.data += self.bayar_final

        # Symmetric padding
        x = symm_pad(x, (2, 2, 2, 2))

        conv_init = self.init_conv(x)
        conv_bayar = self.BayarConv2D(x)
        conv_srm = self.SRMConv2D(x)

        first_block = torch.cat([conv_init, conv_srm, conv_bayar], axis=1)
        first_block = self.relu(first_block)

        last_block = first_block

        for layer in self.middle_and_last_block:

            if isinstance(layer, nn.Conv2d):
                last_block = symm_pad(last_block, (1, 1, 1, 1))

            last_block = layer(last_block)

        # L2 normalization
        last_block = F.normalize(last_block, dim=1, p=2)

        ## Local Anomaly Feature Extraction
        X_adapt = self.adaptation(last_block)
        X_adapt = batch_norm(X_adapt)

        # Z-pool concatenation
        mu_T = self.GlobalPool(X_adapt)
        sigma_T = torch.sqrt(self.GlobalPool(torch.square(X_adapt - mu_T)))
        sigma_T = torch.max(sigma_T, self.sigma_F + self.eps)
        inv_sigma_T = torch.pow(sigma_T, -1)
        zpoolglobal = torch.abs((mu_T - X_adapt) * inv_sigma_T)

        mu_31 = self.pool31(X_adapt)
        zpool31 = torch.abs((mu_31 - X_adapt) * inv_sigma_T)

        mu_15 = self.pool15(X_adapt)
        zpool15 = torch.abs((mu_15 - X_adapt) * inv_sigma_T)

        mu_7 = self.pool7(X_adapt)
        zpool7 = torch.abs((mu_7 - X_adapt) * inv_sigma_T)

        input_lstm = torch.cat(
            [
                zpool7.unsqueeze(0),
                zpool15.unsqueeze(0),
                zpool31.unsqueeze(0),
                zpoolglobal.unsqueeze(0),
            ],
            axis=0,
        )

        # Conv2DLSTM
        _, output_lstm = self.convlstm(input_lstm)
        output_lstm = output_lstm[0][0]

        final_output = self.end(output_lstm)

        return final_output
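# Illustrative shape check (editorial addition, not part of this commit): a
# forward pass maps a (B, 3, H, W) image in [0, 255] to a (B, 1, H, W)
# per-pixel forgery probability map. Building MantraNet requires the LFS
# weight file locally, since the fixed SRM filters are read from it.
def _demo_mantranet():
    net = MantraNet(device=torch.device("cpu"))
    x = torch.randint(0, 256, (1, 3, 64, 64)).float()
    assert net(x).shape == (1, 1, 64, 64)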
685 |
+
# Slight modification of the original MantraNet, using a GRU instead of an LSTM
class MantraNet_GRU(nn.Module):
    def __init__(self, device, in_channel=3, eps=10 ** (-4)):
        super(MantraNet_GRU, self).__init__()

        self.eps = eps
        self.relu = nn.ReLU()
        self.device = device

        # ********** IMAGE MANIPULATION TRACE FEATURE EXTRACTOR *********

        ## Initialisation

        self.init_conv = nn.Conv2d(in_channel, 4, 5, 1, padding=0, bias=False)

        self.BayarConv2D = nn.Conv2d(in_channel, 3, 5, 1, padding=0, bias=False)

        self.SRMConv2D = nn.Conv2d(in_channel, 9, 5, 1, padding=0, bias=False)

        # NB: this relative path assumes the process runs from the MantraNet/ folder
        self.SRMConv2D.weight.data = torch.load("MantraNetv4.pt")["SRMConv2D.weight"]

        ## SRM filters (fixed)
        for param in self.SRMConv2D.parameters():
            param.requires_grad = False

        self.middle_and_last_block = nn.ModuleList(
            [
                nn.Conv2d(16, 32, 3, 1, padding=0),
                nn.ReLU(),
                nn.Conv2d(32, 64, 3, 1, padding=0),
                nn.ReLU(),
                nn.Conv2d(64, 64, 3, 1, padding=0),
                nn.ReLU(),
                nn.Conv2d(64, 128, 3, 1, padding=0),
                nn.ReLU(),
                nn.Conv2d(128, 128, 3, 1, padding=0),
                nn.ReLU(),
                nn.Conv2d(128, 128, 3, 1, padding=0),
                nn.ReLU(),
                nn.Conv2d(128, 256, 3, 1, padding=0),
                nn.ReLU(),
                nn.Conv2d(256, 256, 3, 1, padding=0),
                nn.ReLU(),
                nn.Conv2d(256, 256, 3, 1, padding=0),
                nn.ReLU(),
                nn.Conv2d(256, 256, 3, 1, padding=0),
                nn.ReLU(),
                nn.Conv2d(256, 256, 3, 1, padding=0),
                nn.ReLU(),
                nn.Conv2d(256, 256, 3, 1, padding=0),
            ]
        )

        # ********** LOCAL ANOMALY DETECTOR *********

        self.adaptation = nn.Conv2d(256, 64, 1, 1, padding=0, bias=False)

        self.sigma_F = nn.Parameter(torch.zeros((1, 64, 1, 1)), requires_grad=True)

        self.pool31 = nn.AvgPool2d(31, stride=1, padding=15, count_include_pad=False)
        self.pool15 = nn.AvgPool2d(15, stride=1, padding=7, count_include_pad=False)
        self.pool7 = nn.AvgPool2d(7, stride=1, padding=3, count_include_pad=False)

        self.convgru = ConvGRU(
            input_dim=64,
            hidden_dim=8,
            kernel_size=(7, 7),
            num_layers=1,
            batch_first=False,
            bias=True,
            return_all_layers=False,
        )

        self.end = nn.Sequential(nn.Conv2d(8, 1, 7, 1, padding=3), nn.Sigmoid())

        self.bayar_mask = torch.ones((5, 5), device=self.device)
        self.bayar_final = torch.zeros((5, 5), device=self.device)

    def forward(self, x):
        B, nb_channel, H, W = x.shape

        if not self.training:
            self.GlobalPool = nn.AvgPool2d((H, W), stride=1)
        else:
            if not hasattr(self, "GlobalPool"):
                self.GlobalPool = nn.AvgPool2d((H, W), stride=1)

        # Normalization
        x = x / 255.0 * 2 - 1

        ## Image Manipulation Trace Feature Extractor

        ## **Bayar constraints**

        self.bayar_mask[2, 2] = 0
        self.bayar_final[2, 2] = -1

        self.BayarConv2D.weight.data *= self.bayar_mask
        self.BayarConv2D.weight.data *= torch.pow(
            self.BayarConv2D.weight.data.sum(axis=(2, 3)).view(3, 3, 1, 1), -1
        )
        self.BayarConv2D.weight.data += self.bayar_final
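        # Explanatory note (added): the three steps above enforce the Bayar &
        # Stamm constrained-convolution scheme: the centre tap of every 5x5
        # kernel is fixed to -1 while the off-centre taps are renormalised to
        # sum to 1, turning each kernel into a prediction-error filter that
        # reacts to manipulation residuals rather than image content.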

        # Symmetric padding
        X = symm_pad(x, (2, 2, 2, 2))

        conv_init = self.init_conv(X)
        conv_bayar = self.BayarConv2D(X)
        conv_srm = self.SRMConv2D(X)

        first_block = torch.cat([conv_init, conv_srm, conv_bayar], axis=1)
        first_block = self.relu(first_block)

        last_block = first_block

        for layer in self.middle_and_last_block:

            if isinstance(layer, nn.Conv2d):
                last_block = symm_pad(last_block, (1, 1, 1, 1))

            last_block = layer(last_block)

        # L2 normalization
        last_block = F.normalize(last_block, dim=1, p=2)

        ## Local Anomaly Feature Extraction
        X_adapt = self.adaptation(last_block)
        X_adapt = batch_norm(X_adapt)

        # Z-pool concatenation
        mu_T = self.GlobalPool(X_adapt)
        sigma_T = torch.sqrt(self.GlobalPool(torch.square(X_adapt - mu_T)))
        sigma_T = torch.max(sigma_T, self.sigma_F + self.eps)
        inv_sigma_T = torch.pow(sigma_T, -1)
        zpoolglobal = torch.abs((mu_T - X_adapt) * inv_sigma_T)

        mu_31 = self.pool31(X_adapt)
        zpool31 = torch.abs((mu_31 - X_adapt) * inv_sigma_T)

        mu_15 = self.pool15(X_adapt)
        zpool15 = torch.abs((mu_15 - X_adapt) * inv_sigma_T)

        mu_7 = self.pool7(X_adapt)
        zpool7 = torch.abs((mu_7 - X_adapt) * inv_sigma_T)

        input_gru = torch.cat(
            [
                zpool7.unsqueeze(0),
                zpool15.unsqueeze(0),
                zpool31.unsqueeze(0),
                zpoolglobal.unsqueeze(0),
            ],
            axis=0,
        )
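        # Note (added): unsqueeze(0) gives each z-score map a leading "time"
        # axis, so the four maps (windows 7, 15, 31, then global) are stacked
        # into a 4-step sequence for the convolutional GRU (batch_first=False).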

        # Conv2DGRU (the GRU counterpart of the Conv2DLSTM used above)
        _, output_gru = self.convgru(input_gru)
        output_gru = output_gru[0]

        final_output = self.end(output_gru)

        return final_output


## Use pre-trained weights:
def pre_trained_model(weight_path="MantraNet/MantraNetv4.pt", device=device):
    model = MantraNet(device=device)
    model.load_state_dict(torch.load(weight_path))
    return model
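
# Usage sketch (added; mirrors the call made in app.py):
#
#   model = pre_trained_model(weight_path="MantraNet/MantraNetv4.pt", device="cpu")
#   model.eval()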

# Predict a forgery mask for an image
def check_forgery(model, img_path="./example.jpg", device=device):

    model.to(device)
    model.eval()

    im = Image.open(img_path)
    im = np.array(im)
    original_image = im.copy()

    # HWC uint8 image -> 1xCxHxW float tensor
    im = torch.Tensor(im)
    im = im.unsqueeze(0)
    im = im.transpose(2, 3).transpose(1, 2)
    im = im.to(device)

    with torch.no_grad():
        final_output = model(im)

    fig = plt.figure(figsize=(20, 20))

    plt.subplot(1, 3, 1)
    plt.imshow(original_image)
    plt.title("Original image")

    plt.subplot(1, 3, 2)
    plt.imshow((final_output[0][0]).cpu().detach(), cmap="gray")
    plt.title("Predicted forgery mask")

    # Keep only the pixels whose predicted forgery probability exceeds 0.2
    plt.subplot(1, 3, 3)
    plt.imshow(
        (final_output[0][0].cpu().detach().unsqueeze(2) > 0.2)
        * torch.tensor(original_image)
    )
    plt.title("Suspicious regions detected")

    return fig

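# Usage sketch (added): generate and save the three-panel figure.
#
#   model = pre_trained_model("MantraNet/MantraNetv4.pt", device="cpu")
#   fig = check_forgery(model, img_path="./example.jpg", device="cpu")
#   fig.savefig("forgery_check.png")
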
class ForgeryDetector(pl.LightningModule):

    # Model initialization/creation. Note: the default detector is built once,
    # at class-definition time (Python evaluates default arguments eagerly).
    def __init__(self, train_loader, detector=MantraNet(), lr=0.001):
        super(ForgeryDetector, self).__init__()

        self.detector = detector
        self.train_loader = train_loader
        self.cpt = -1
        self.lr = lr

    # Forward pass of the model
    def forward(self, x):
        return self.detector(x)

    # Loss function
    def loss(self, y_hat, y):
        return nn.BCELoss()(y_hat, y)

    # Optimizers
    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.detector.parameters(), lr=self.lr)
        # scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1)

        # Return the list of optimizers; the second, empty list is for
        # schedulers (if any)
        return [optimizer], []

    # Called after prepare_data; returns the training DataLoader
    def train_dataloader(self):

        return self.train_loader

    # Training loop
    def training_step(self, batch, batch_idx):
        # batch returns x and y tensors
        real_images, mask = batch
        B, _, _, _ = real_images.size()
        self.cpt += 1

        predicted = self.detector(real_images).view(B, -1)
        mask = mask.view(B, -1)

        loss = self.loss(predicted, mask)

        self.log("BCELoss", loss, on_step=True, on_epoch=True, prog_bar=True)

        output = OrderedDict(
            {
                "loss": loss,
            }
        )

        return output
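
# Training sketch (added; the Trainer wiring below is an assumption, not part
# of the original file):
#
#   import pytorch_lightning as pl
#   from torch.utils.data import DataLoader
#
#   loader = DataLoader(my_dataset, batch_size=4)  # yields (image, mask) pairs
#   module = ForgeryDetector(train_loader=loader, detector=MantraNet(device="cpu"))
#   pl.Trainer(max_epochs=10).fit(module)
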
__init__.py
ADDED
@@ -0,0 +1,2 @@
from MantraNet.mantranet import pre_trained_model, check_forgery
from app import check_image_buster, check_image_mantra
app.py
ADDED
@@ -0,0 +1,37 @@
import os
from matplotlib import pyplot as plt
from MantraNet.mantranet import pre_trained_model, check_forgery
from BusterNet.BusterNetCore import create_BusterNet_testing_model
from BusterNet.BusterNetUtils import simple_cmfd_decoder, visualize_result
import streamlit as st
import cv2

os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

st.header("IMD Demo")
device = "cpu"  # change this if you have a GPU with at least 12 GB of RAM (it will save you a lot of time!)


def check_image_buster(img_path):
    busterNetModel = create_BusterNet_testing_model("BusterNet/pretrained_busterNet.hd5")
    rgb = cv2.imread(img_path)
    pred = simple_cmfd_decoder(busterNetModel, rgb)
    figure = visualize_result(rgb, pred, pred, figsize=(20, 20), title="BusterNet CMFD")
    st.pyplot(figure)


def check_image_mantra(img_path):
    MantraNetmodel = pre_trained_model(
        weight_path="MantraNet/MantraNetv4.pt", device=device
    )
    fig = check_forgery(MantraNetmodel, img_path=img_path, device=device)
    st.pyplot(fig)


uploaded_image = st.file_uploader("Upload your image", type=["jpg", "png", "jpeg"])
if uploaded_image is not None:
    # Make sure the target directory exists before saving the upload
    os.makedirs("images", exist_ok=True)
    with open(os.path.join("images", uploaded_image.name), "wb") as f:
        f.write(uploaded_image.read())
    st.write("BusterNet")
    check_image_buster(os.path.join("images", uploaded_image.name))
    st.write("MantraNet")
    check_image_mantra(os.path.join("images", uploaded_image.name))
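
# Run note (added; assumes a standard Streamlit setup): start the demo from
# the repository root with
#
#   streamlit run app.py
#
# so that the relative paths "BusterNet/pretrained_busterNet.hd5" and
# "MantraNet/MantraNetv4.pt" resolve correctly.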