# Copyright (c) 2021 Anish Athalye. Released under the MIT license.
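#
# Generate an image whose NeuralHash collides with a target hash: starting
# from a supplied image, run gradient descent through the NeuralHash network
# and its seeded hyperplane projection until the image's 96-bit hash matches
# the target.
#
# Example invocation (file names here are illustrative):
#
#   python collide.py --image input.png --target 59a34eabe31910abfb06f308
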
import argparse
import os
import time

import numpy as np
import tensorflow as tf
from scipy.ndimage import gaussian_filter

from util import *

DEFAULT_MODEL_PATH = 'model.onnx'
DEFAULT_SEED_PATH = 'neuralhash_128x96_seed1.dat'
DEFAULT_TARGET_HASH = '59a34eabe31910abfb06f308'
DEFAULT_ITERATIONS = 500
DEFAULT_SAVE_ITERATIONS = 0
DEFAULT_LR = 2.0
DEFAULT_COMBINED_THRESHOLD = 2
DEFAULT_K = 10.0
DEFAULT_CLIP_RANGE = 0.1
DEFAULT_W_L2 = 2e-3
DEFAULT_W_TV = 1e-4
DEFAULT_W_HASH = 0.8
DEFAULT_BLUR = 0


def main():
tf.compat.v1.disable_eager_execution()
options = get_options()
model = load_model(options.model)
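    # the input placeholder and the 128-dimensional network output ("logits")
    # that feeds NeuralHash's final hashing step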
image = model.tensor_dict['image']
logits = model.tensor_dict['leaf/logits']
seed = load_seed(options.seed)
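    # the seed is a 96x128 matrix of hyperplanes; the signs of the projected
    # logits give the 96 hash bits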
    h = hash_from_hex(options.target)
    original = load_image(options.image)
with model.graph.as_default():
with tf.compat.v1.Session() as sess:
sess.run(tf.compat.v1.global_variables_initializer())
proj = tf.reshape(tf.linalg.matmul(seed, tf.reshape(logits, (128, 1))), (96,))
# proj is in R^96; it's interpreted as a 96-bit hash by mapping
# entries < 0 to the bit '0', and entries >= 0 to the bit '1'
normalized, _ = tf.linalg.normalize(proj)
hash_output = tf.sigmoid(normalized * options.k)
# now, hash_output has entries in (0, 1); it's interpreted by
# mapping entries < 0.5 to the bit '0' and entries >= 0.5 to the
# bit '1'
# we clip hash_output to (clip_range, 1-clip_range); this seems to
# improve the search (we don't "waste" perturbation tweaking
# "strong" bits); the sigmoid already does this to some degree, but
# this seems to help
hash_output = tf.clip_by_value(hash_output, options.clip_range, 1.0 - options.clip_range) - 0.5
hash_output = hash_output * (0.5 / (0.5 - options.clip_range))
hash_output = hash_output + 0.5
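            # e.g. with the default clip_range of 0.1: clip to [0.1, 0.9],
            # then the affine map above stretches that back onto [0, 1]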
# hash loss: how far away we are from the target hash
hash_loss = tf.math.reduce_sum(tf.math.squared_difference(hash_output, h))
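            # (h is the target hash as a vector of 0/1 bit values, so this sum
            # acts as a differentiable surrogate for the Hamming distance)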
perturbation = image - original
# image loss: how big / noticeable is the perturbation?
img_loss = options.l2_weight * tf.nn.l2_loss(perturbation) + options.tv_weight * tf.image.total_variation(perturbation)[0]
# combined loss: try to minimize both at once
combined_loss = options.hash_weight * hash_loss + (1 - options.hash_weight) * img_loss
# gradients of all the losses
g_hash_loss, = tf.gradients(hash_loss, image)
g_img_loss, = tf.gradients(img_loss, image)
g_combined_loss, = tf.gradients(combined_loss, image)
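            # (built once here so no new graph nodes are added inside the loop)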
# perform attack
x = original
best = (float('inf'), 0) # (distance, image quality loss)
dist = float('inf')
for i in range(options.iterations):
# we do an alternating projections style attack here; if we
# haven't found a colliding image yet, only optimize for that;
# if we have a colliding image, then minimize the size of the
# perturbation; if we're close, then do both at once
if dist == 0:
loss_name, loss, g = 'image', img_loss, g_img_loss
elif best[0] == 0 and dist <= options.combined_threshold:
loss_name, loss, g = 'combined', combined_loss, g_combined_loss
else:
loss_name, loss, g = 'hash', hash_loss, g_hash_loss
# compute loss values and gradient
xq = quantize(x) # take derivatives wrt the quantized version of the image
hash_output_v, img_loss_v, loss_v, g_v = sess.run([hash_output, img_loss, loss, g], feed_dict={image: xq})
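                # Hamming distance, in bits, between the current and target hash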
dist = np.sum((hash_output_v >= 0.5) != (h >= 0.5))
                # if it's better than any image found so far, save it; also
                # save periodically when --save-iterations is set
                score = (dist, img_loss_v)
                if score < best or (options.save_iterations > 0 and (i+1) % options.save_iterations == 0):
                    save_image(x, os.path.join(options.save_directory, 'best.png'))
                if score < best:
                    best = score
# gradient descent step
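                # normalizing the gradient fixes each step's L2 length, so
                # learning_rate directly controls the step size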
g_v_norm = g_v / np.linalg.norm(g_v)
x = x - options.learning_rate * g_v_norm
if options.blur > 0:
x = blur_perturbation(original, x, options.blur)
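                # keep pixels within the model's expected input range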
x = x.clip(-1, 1)
print('iteration: {}/{}, best: ({}, {:.3f}), hash: {}, distance: {}, loss: {:.3f} ({})'.format(
i+1,
options.iterations,
best[0],
best[1],
hash_to_hex(hash_output_v),
dist,
loss_v,
loss_name
))


def quantize(x):
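    # simulate 8-bit image quantization: map [-1, 1] onto [0, 255], truncate
    # to uint8, and map back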
x = (x + 1.0) * (255.0 / 2.0)
x = x.astype(np.uint8).astype(np.float32)
x = x / (255.0 / 2.0) - 1.0
return x


def blur_perturbation(original, x, sigma):
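    # blur only the perturbation, not the underlying image, so the original
    # content stays sharp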
perturbation = x - original
perturbation = gaussian_filter_by_channel(perturbation, sigma=sigma)
return original + perturbation


def gaussian_filter_by_channel(x, sigma):
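    # apply a 2-D Gaussian blur to each channel of a 1xCxHxW array independently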
return np.stack([gaussian_filter(x[0,ch,:,:], sigma) for ch in range(x.shape[1])])[np.newaxis]


def get_options():
parser = argparse.ArgumentParser()
parser.add_argument('--image', type=str, help='path to starting image', required=True)
parser.add_argument('--model', type=str, help='path to model', default=DEFAULT_MODEL_PATH)
parser.add_argument('--seed', type=str, help='path to seed', default=DEFAULT_SEED_PATH)
parser.add_argument('--target', type=str, help='target hash', default=DEFAULT_TARGET_HASH)
parser.add_argument('--learning-rate', type=float, help='learning rate', default=DEFAULT_LR)
parser.add_argument('--combined-threshold', type=int, help='threshold to start using combined loss', default=DEFAULT_COMBINED_THRESHOLD)
    parser.add_argument('--k', type=float, help='sigmoid steepness k; larger values push hash bits toward 0/1', default=DEFAULT_K)
parser.add_argument('--l2-weight', type=float, help='perturbation l2 loss weight', default=DEFAULT_W_L2)
parser.add_argument('--tv-weight', type=float, help='perturbation total variation loss weight', default=DEFAULT_W_TV)
parser.add_argument('--hash-weight', type=float, help='relative weight (0.0 to 1.0) of hash in combined loss', default=DEFAULT_W_HASH)
    parser.add_argument('--clip-range', type=float, help='clip sigmoid outputs to (clip-range, 1 - clip-range) before rescaling', default=DEFAULT_CLIP_RANGE)
parser.add_argument('--iterations', type=int, help='max number of iterations', default=DEFAULT_ITERATIONS)
parser.add_argument('--save-directory', type=str, help='directory to save output images', default='.')
    parser.add_argument('--save-iterations', type=int, help='save an image every this many iterations, regardless of improvement (0 disables periodic saves)', default=DEFAULT_SAVE_ITERATIONS)
    parser.add_argument('--blur', type=float, help='apply Gaussian blur with this sigma to the perturbation on every step (0 disables blurring)', default=DEFAULT_BLUR)
return parser.parse_args()


if __name__ == '__main__':
    start = time.time()
    main()
    end = time.time()
    print('total time: {:.1f} seconds'.format(end - start))