# Adversarial attack utilities: crafts image perturbations against an
# ensemble of pre-trained feature extractors.
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch.autograd import Variable

from util.feature_extraction_utils import warp_image, normalize_batch
from util.prepare_utils import get_ensemble, extract_features
from lpips_pytorch import LPIPS

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tensor_transform = transforms.ToTensor()
pil_transform = transforms.ToPILImage()
class Attack(nn.Module):
    """Iterative signed-gradient attack on an ensemble of feature extractors.

    Perturbs input images so that their extracted features move toward
    (direction == -1) or away from (direction == 1) a target feature
    direction, optionally regularized by an LPIPS perceptual term and a
    total-variation term.
    """

    def __init__(
        self,
        models,
        dim,
        attack_type,
        eps,
        c_sim=0.5,
        net_type="alex",
        lr=0.05,
        n_iters=100,
        noise_size=0.001,
        n_starts=10,
        c_tv=None,
        sigma_gf=None,
        kernel_size_gf=None,
        combination=False,
        warp=False,
        theta_warp=None,
        V_reduction=None,
    ):
        super(Attack, self).__init__()
        self.extractor_ens = get_ensemble(
            models, sigma_gf, kernel_size_gf, combination, V_reduction, warp, theta_warp
        )
        self.dim = dim
        self.eps = eps
        self.c_sim = c_sim
        self.net_type = net_type
        self.lr = lr
        self.n_iters = n_iters
        self.noise_size = noise_size
        self.n_starts = n_starts
        self.c_tv = c_tv
        self.attack_type = attack_type
        self.warp = warp
        self.theta_warp = theta_warp
        if self.attack_type == "lpips":
            self.lpips_loss = LPIPS(self.net_type).to(device)
    def execute(self, images, dir_vec, direction):
        images = Variable(images).to(device)
        dir_vec = dir_vec.to(device)
        # feature-wise norm of the target direction, used to normalize distances
        dir_vec_norm = dir_vec.norm(dim=2).unsqueeze(2).to(device)
        dist = torch.zeros(images.shape[0]).to(device)
        adv_images = images.detach().clone()
        if self.warp:
            self.face_img = warp_image(images, self.theta_warp)
        for start in range(self.n_starts):
            # remember the best result from previous restarts
            adv_images_old = adv_images.detach().clone()
            dist_old = dist.clone()
            # random restart: initialize with uniform noise in (-noise_size, noise_size)
            noise_uniform = Variable(
                2 * self.noise_size * torch.rand(images.size()) - self.noise_size
            ).to(device)
            adv_images = Variable(
                images.detach().clone() + noise_uniform, requires_grad=True
            ).to(device)
            for i in range(self.n_iters):
                adv_features = extract_features(
                    adv_images, self.extractor_ens, self.dim
                ).to(device)
                # normalized squared feature distance to the target direction;
                # direction == 1 pushes features away from dir_vec, -1 pulls toward it
                loss = direction * torch.mean(
                    (adv_features - dir_vec) ** 2 / dir_vec_norm
                )
                if self.c_tv is not None:
                    tv_out = self.total_var_reg(images, adv_images)
                    loss -= self.c_tv * tv_out
                if self.attack_type == "lpips":
                    lpips_out = self.lpips_reg(images, adv_images)
                    loss -= self.c_sim * lpips_out
                # signed-gradient ascent step on the loss
                grad = torch.autograd.grad(loss, [adv_images])
                adv_images = adv_images + self.lr * grad[0].sign()
                perturbation = adv_images - images
                if self.attack_type == "sgd":
                    # project the perturbation back into the l-infinity eps-ball
                    perturbation = torch.clamp(
                        perturbation, min=-self.eps, max=self.eps
                    )
                adv_images = images + perturbation
                adv_images = torch.clamp(adv_images, min=0, max=1)
            adv_features = extract_features(
                adv_images, self.extractor_ens, self.dim
            ).to(device)
            dist = torch.mean((adv_features - dir_vec) ** 2 / dir_vec_norm, dim=[1, 2])
            # keep, per image, the restart that achieved the better objective
            if direction == 1:
                adv_images[dist < dist_old] = adv_images_old[dist < dist_old]
                dist[dist < dist_old] = dist_old[dist < dist_old]
            else:
                adv_images[dist > dist_old] = adv_images_old[dist > dist_old]
                dist[dist > dist_old] = dist_old[dist > dist_old]
        return adv_images.detach().cpu()
    def lpips_reg(self, images, adv_images):
        # LPIPS perceptual distance between clean and adversarial images,
        # averaged over the batch; lpips_pytorch returns a (1, 1, 1, 1)
        # tensor, hence the [0][0][0][0] indexing
        if self.warp:
            # average the LPIPS terms of the warped face crop and the full image
            face_adv = warp_image(adv_images, self.theta_warp)
            lpips_out = self.lpips_loss(
                normalize_batch(self.face_img).to(device),
                normalize_batch(face_adv).to(device),
            )[0][0][0][0] / (2 * adv_images.shape[0])
            lpips_out += self.lpips_loss(
                normalize_batch(images).to(device),
                normalize_batch(adv_images).to(device),
            )[0][0][0][0] / (2 * adv_images.shape[0])
        else:
            lpips_out = (
                self.lpips_loss(
                    normalize_batch(images).to(device),
                    normalize_batch(adv_images).to(device),
                )[0][0][0][0]
                / adv_images.shape[0]
            )
        return lpips_out
    def total_var_reg(self, images, adv_images):
        # anisotropic total variation of the perturbation, penalizing
        # high-frequency noise patterns
        perturbation = adv_images - images
        tv = torch.mean(
            torch.abs(perturbation[:, :, :, :-1] - perturbation[:, :, :, 1:])
        ) + torch.mean(
            torch.abs(perturbation[:, :, :-1, :] - perturbation[:, :, 1:, :])
        )
        return tv
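
# A minimal usage sketch, not part of the original module: the model
# identifiers, feature dimension, image size, and dir_vec layout below are
# assumptions for illustration; consult util.prepare_utils.get_ensemble for
# the identifiers and shapes the repo actually expects.
if __name__ == "__main__":
    models = ["model_a", "model_b"]  # hypothetical extractor identifiers
    dim = 512                        # assumed feature dimension
    attack = Attack(
        models,
        dim,
        attack_type="lpips",  # LPIPS similarity constraint instead of an eps-ball
        eps=0.05,
        lr=0.0025,
        n_iters=50,
        n_starts=1,
    )
    images = torch.rand(1, 3, 112, 112)  # dummy image batch in [0, 1]
    # target feature direction: (batch, ensemble size, feature dim), assumed layout
    dir_vec = torch.randn(1, len(models), dim)
    adv = attack.execute(images, dir_vec, direction=1)  # 1 = push features away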