# NOTE(review): removed non-source extraction residue ("Spaces:" / "Runtime error" lines).
# Copyright (c) 2021 Anish Athalye. Released under the MIT license. | |
import numpy as np | |
import tensorflow as tf | |
from scipy.ndimage.filters import gaussian_filter | |
import argparse | |
import os | |
import time | |
from util import * | |
# Default configuration; every value is overridable via a CLI flag (see get_options).
DEFAULT_MODEL_PATH = 'model.onnx'  # NeuralHash ONNX model
DEFAULT_SEED_PATH = 'neuralhash_128x96_seed1.dat'  # projection seed matrix (maps 128-dim logits to 96 bits)
DEFAULT_TARGET_HASH = '59a34eabe31910abfb06f308'  # hex-encoded 96-bit target hash
DEFAULT_ITERATIONS = 500  # max gradient-descent steps
DEFAULT_SAVE_ITERATIONS = 0  # save every N iterations; 0 disables periodic saving
DEFAULT_LR = 2.0  # step size along the normalized gradient
DEFAULT_COMBINED_THRESHOLD = 2  # hash distance at which the combined loss is used
DEFAULT_K = 10.0  # sigmoid sharpness for the soft hash bits
DEFAULT_CLIP_RANGE = 0.1  # soft-bit clipping margin applied to the sigmoid output
DEFAULT_W_L2 = 2e-3  # weight of the perturbation L2 penalty
DEFAULT_W_TV = 1e-4  # weight of the perturbation total-variation penalty
DEFAULT_W_HASH = 0.8  # relative weight (0..1) of the hash loss in the combined loss
DEFAULT_BLUR = 0  # Gaussian-blur sigma applied to the perturbation each step; 0 = off
def main():
    """Search for an image whose NeuralHash collides with a target hash.

    Loads the ONNX NeuralHash model, builds a differentiable surrogate of the
    96-bit hash, and runs normalized gradient descent on the input image to
    drive its hash toward the target while (optionally) keeping the
    perturbation small.  The best colliding image found is written to
    <save-directory>/best.png.
    """
    tf.compat.v1.disable_eager_execution()
    options = get_options()

    model = load_model(options.model)
    image = model.tensor_dict['image']          # input placeholder
    logits = model.tensor_dict['leaf/logits']   # 128-dim embedding output
    seed = load_seed(options.seed)              # projection matrix, logits -> 96 values
    # target hash as a vector of 96 bits; previously this was computed twice
    # (into an unused `target` and into `h`) -- compute it once
    h = hash_from_hex(options.target)
    original = load_image(options.image)

    with model.graph.as_default():
        with tf.compat.v1.Session() as sess:
            sess.run(tf.compat.v1.global_variables_initializer())
            # proj is in R^96; it's interpreted as a 96-bit hash by mapping
            # entries < 0 to the bit '0', and entries >= 0 to the bit '1'
            proj = tf.reshape(tf.linalg.matmul(seed, tf.reshape(logits, (128, 1))), (96,))
            normalized, _ = tf.linalg.normalize(proj)
            # squash to (0, 1); k controls how sharp the soft thresholding is
            hash_output = tf.sigmoid(normalized * options.k)
            # now, hash_output has entries in (0, 1); it's interpreted by
            # mapping entries < 0.5 to the bit '0' and entries >= 0.5 to the
            # bit '1'
            # we clip hash_output to (clip_range, 1-clip_range); this seems to
            # improve the search (we don't "waste" perturbation tweaking
            # "strong" bits); the sigmoid already does this to some degree, but
            # this seems to help
            hash_output = tf.clip_by_value(hash_output, options.clip_range, 1.0 - options.clip_range) - 0.5
            hash_output = hash_output * (0.5 / (0.5 - options.clip_range))
            hash_output = hash_output + 0.5
            # hash loss: how far away we are from the target hash
            hash_loss = tf.math.reduce_sum(tf.math.squared_difference(hash_output, h))
            perturbation = image - original
            # image loss: how big / noticeable is the perturbation?
            img_loss = options.l2_weight * tf.nn.l2_loss(perturbation) + options.tv_weight * tf.image.total_variation(perturbation)[0]
            # combined loss: try to minimize both at once
            combined_loss = options.hash_weight * hash_loss + (1 - options.hash_weight) * img_loss
            # gradients of all the losses with respect to the input image
            g_hash_loss, = tf.gradients(hash_loss, image)
            g_img_loss, = tf.gradients(img_loss, image)
            g_combined_loss, = tf.gradients(combined_loss, image)

            # perform attack
            x = original
            best = (float('inf'), 0)  # (hash distance, image quality loss)
            dist = float('inf')
            for i in range(options.iterations):
                # we do an alternating projections style attack here; if we
                # haven't found a colliding image yet, only optimize for that;
                # if we have a colliding image, then minimize the size of the
                # perturbation; if we're close, then do both at once
                if dist == 0:
                    loss_name, loss, g = 'image', img_loss, g_img_loss
                elif best[0] == 0 and dist <= options.combined_threshold:
                    loss_name, loss, g = 'combined', combined_loss, g_combined_loss
                else:
                    loss_name, loss, g = 'hash', hash_loss, g_hash_loss
                # compute loss values and gradient
                xq = quantize(x)  # take derivatives wrt the quantized version of the image
                hash_output_v, img_loss_v, loss_v, g_v = sess.run([hash_output, img_loss, loss, g], feed_dict={image: xq})
                dist = np.sum((hash_output_v >= 0.5) != (h >= 0.5))
                if dist == 0:
                    # exact collision: save the result and stop
                    save_image(x, os.path.join(options.save_directory, 'best.png'))
                    break
                # if it's better than any image found so far, save it
                # (fix: the save was previously commented out, leaving a dead
                # `pass` branch and making --save-iterations a no-op)
                score = (dist, img_loss_v)
                if score < best:
                    save_image(x, os.path.join(options.save_directory, 'best.png'))
                    best = score
                # optionally checkpoint periodically regardless of improvement,
                # under an iteration-stamped name so snapshots don't clobber best.png
                if options.save_iterations > 0 and (i + 1) % options.save_iterations == 0:
                    save_image(x, os.path.join(options.save_directory, 'iteration_{}.png'.format(i + 1)))
                # normalized gradient descent step
                g_v_norm = g_v / np.linalg.norm(g_v)
                x = x - options.learning_rate * g_v_norm
                if options.blur > 0:
                    # low-pass the perturbation to keep it less noticeable
                    x = blur_perturbation(original, x, options.blur)
                x = x.clip(-1, 1)
                print('iteration: {}/{}, best: ({}, {:.3f}), hash: {}, distance: {}, loss: {:.3f} ({})'.format(
                    i+1,
                    options.iterations,
                    best[0],
                    best[1],
                    hash_to_hex(hash_output_v),
                    dist,
                    loss_v,
                    loss_name
                ))
def quantize(x):
    """Simulate 8-bit storage: map [-1, 1] onto uint8 levels and back.

    The returned array is float32 again, but only takes the 256 values an
    image file can actually represent.
    """
    half_range = 255.0 / 2.0
    levels = ((x + 1.0) * half_range).astype(np.uint8)
    return levels.astype(np.float32) / half_range - 1.0
def blur_perturbation(original, x, sigma):
    """Low-pass filter the adversarial delta, leaving the base image sharp.

    Extracts the perturbation (x - original), Gaussian-blurs it per channel,
    and re-applies it to the original image.
    """
    delta = x - original
    smoothed = gaussian_filter_by_channel(delta, sigma=sigma)
    return original + smoothed
def gaussian_filter_by_channel(x, sigma):
    """Blur each channel of a batch-of-one NCHW array independently.

    Applies a 2-D Gaussian filter to every (H, W) channel slice of x[0] and
    restacks the result with the leading batch axis restored.
    """
    blurred_channels = [gaussian_filter(channel, sigma) for channel in x[0]]
    return np.stack(blurred_channels)[np.newaxis]
def get_options():
    """Build the CLI parser and return the parsed arguments namespace."""
    p = argparse.ArgumentParser()
    add = p.add_argument  # bind once; every flag below goes through it
    add('--image', type=str, help='path to starting image', required=True)
    add('--model', type=str, help='path to model', default=DEFAULT_MODEL_PATH)
    add('--seed', type=str, help='path to seed', default=DEFAULT_SEED_PATH)
    add('--target', type=str, help='target hash', default=DEFAULT_TARGET_HASH)
    add('--learning-rate', type=float, help='learning rate', default=DEFAULT_LR)
    add('--combined-threshold', type=int, help='threshold to start using combined loss', default=DEFAULT_COMBINED_THRESHOLD)
    add('--k', type=float, help='k parameter', default=DEFAULT_K)
    add('--l2-weight', type=float, help='perturbation l2 loss weight', default=DEFAULT_W_L2)
    add('--tv-weight', type=float, help='perturbation total variation loss weight', default=DEFAULT_W_TV)
    add('--hash-weight', type=float, help='relative weight (0.0 to 1.0) of hash in combined loss', default=DEFAULT_W_HASH)
    add('--clip-range', type=float, help='clip range parameter', default=DEFAULT_CLIP_RANGE)
    add('--iterations', type=int, help='max number of iterations', default=DEFAULT_ITERATIONS)
    add('--save-directory', type=str, help='directory to save output images', default='.')
    add('--save-iterations', type=int, help='save this frequently, regardless of improvement', default=DEFAULT_SAVE_ITERATIONS)
    add('--blur', type=float, help='apply Gaussian blur with this sigma on every step', default=DEFAULT_BLUR)
    return p.parse_args()
if __name__ == '__main__':
    # Time the whole attack and print the elapsed seconds when it finishes.
    started_at = time.time()
    main()
    print(time.time() - started_at)