# Uploaded by KC123hello (commit fc0ff8f, "Upload Files", verified).
# Code taken and adapted from https://github.com/wagnermoritz/GSE
import bisect

import torch

from vlm_eval.attacks.attack import Attack
class SparseRS(Attack):
    '''
    L0 SparseRS attack: random search over which k pixels of each image are
    perturbed and which corner colour of the img_range cube each one takes.
    '''

    def __init__(self, model, *args, targeted=False, img_range=(-1, 1),
                 n_queries=10000, k=100, n_restarts=10, alpha_init=0.8,
                 mask_out='none', **kwargs):
        '''
        Implementation of the L0 variant SparseRS https://arxiv.org/abs/2006.12834
        Authors' implementation: https://github.com/fra31/sparse-rs
        Adapted from: https://github.com/wagnermoritz/GSE/tree/main
        args:
        model: Callable, wrapped model. self.model(x) is summed into the
            scalar attack loss; self.model.model.parameters() get frozen.
        targeted: Bool, if True the summed model output is minimised,
            otherwise it is maximised.
        img_range: Tuple of ints/floats, lower and upper bound of image
            entries.
        n_queries: Int, max number of model queries per restart.
        k: Int, number of perturbed pixels per image (clipped to H * W).
        n_restarts: Int, number of restarts with random initialization; the
            restart with the lowest loss is returned.
        alpha_init: Float, initial value for the alpha schedule.
        mask_out: 'none', 'context', 'query' or an int; selects slices along
            dim 1 of the input that are left unperturbed (see _set_mask).
        Extra positional/keyword arguments are accepted and ignored.
        '''
        super().__init__(model, targeted=targeted, img_range=img_range)
        self.n_queries = n_queries
        self.k = k
        self.n_restarts = n_restarts
        self.alpha_init = alpha_init
        # 'none' is the public spelling of "no masking"; store it as None.
        self.mask_out = None if mask_out == 'none' else mask_out

    def _set_mask(self, data):
        '''
        Build a mask shaped like data: 1 where entries may be perturbed, 0
        for slices along dim 1 that must stay clean.
        'context' freezes every slice but the last, 'query' freezes the last
        one, an int freezes that single slice, None freezes nothing.
        Raises NotImplementedError for any other mask_out value.
        '''
        mask = torch.ones_like(data)
        if self.mask_out == 'context':
            mask[:, :-1, ...] = 0
        elif self.mask_out == 'query':
            mask[:, -1, ...] = 0
        elif isinstance(self.mask_out, int):
            mask[:, self.mask_out, ...] = 0
        elif self.mask_out is not None:
            raise NotImplementedError(f'Unknown mask_out: {self.mask_out}')
        return mask

    def __call__(self, x, *args, **kwargs):
        '''
        Perform SparseRS L0 on the input x.

        args:
        x: Tensor whose last three dims are (C, H, W); all leading dims are
            flattened into a stack of images. The original code read dims
            1, 3, 4, 5 as (B, C, H, W), i.e. presumably a
            [b, T_img, F, C, H, W] vision input - TODO confirm with callers.
        Extra positional/keyword arguments are accepted and ignored.

        Returns a tensor of the same shape as x holding the adversarial
        example with the lowest loss over all restarts (x itself if x is
        empty or n_restarts < 1).
        '''
        # The attack only queries the model under no_grad, so no parameter
        # needs gradients.
        for param in self.model.model.parameters():
            param.requires_grad = False
        # Fixed seeds make repeated runs reproducible.
        torch.random.manual_seed(0)
        torch.cuda.random.manual_seed(0)
        x = x.to(self.device)
        x_adv = x.clone()
        if x.numel() == 0:
            return x_adv.detach()
        best_loss = None
        with torch.no_grad():
            for _ in range(self.n_restarts):
                candidate, loss = self.__perturb(x)
                # keep the best restart instead of the last one
                if best_loss is None or loss < best_loss:
                    best_loss, x_adv = loss, candidate
        return x_adv.detach()

    def __perturb(self, x):
        '''
        Run one SparseRS restart from a random k-pixel starting point.

        x: Tensor whose last three dims are (C, H, W); all leading dims are
            flattened into a stack of B images. x itself is not modified.
        Returns (x_adv, loss): the best tensor found (same shape as x) and
        its scalar loss from __lossfn.
        '''
        C, H, W = x.shape[-3:]
        n_pix = H * W
        imgs = x.reshape(-1, C, H, W)
        B = imgs.shape[0]
        # Entries whose mask is 0 are always restored to their clean value.
        keep = self._set_mask(x).reshape(-1, C, H, W).bool()
        batchidx = torch.arange(B, device=self.device).view(-1, 1)
        k = min(self.k, n_pix)

        # M: perturbed pixel indices, U_M: unperturbed pixel indices (per
        # image); a flat index p maps to row p // W and column p % W.
        perm = torch.rand(B, n_pix, device=self.device).argsort(dim=1)
        M, U_M = perm[:, :k], perm[:, k:]

        result = imgs.clone()
        result[batchidx, :, M // W, M % W] = self.__sampleDelta(
            B, C, k, dtype=result.dtype)
        result = torch.where(keep, result, imgs)
        best_loss = self.__lossfn(result.view(x.shape))

        for i in range(1, self.n_queries):
            # swap k_i perturbed pixels for k_i currently unperturbed ones
            k_i = min(max(int(self.__alphaSchedule(i) * k), 1), k, n_pix - k)
            if k_i == 0:
                break
            A_idx = torch.randperm(k, device=self.device)[:k_i]
            B_idx = torch.randperm(n_pix - k, device=self.device)[:k_i]
            A_set, B_set = M[:, A_idx], U_M[:, B_idx]

            z = result.clone()
            # restore the dropped pixels to their clean values
            z[batchidx, :, A_set // W, A_set % W] = \
                imgs[batchidx, :, A_set // W, A_set % W]
            new_color = self.__sampleDelta(B, C, k_i, dtype=z.dtype)
            if k_i == 1:
                # a single changed pixel must actually differ from its
                # current value, otherwise the query is wasted
                old_color = z[batchidx, :, B_set // W, B_set % W]
                while (same := (old_color == new_color)
                       .reshape(B, -1).all(dim=-1)).any():
                    new_color[same] = self.__sampleDelta(
                        int(same.sum().item()), C, k_i, dtype=z.dtype)
            z[batchidx, :, B_set // W, B_set % W] = new_color
            z = torch.where(keep, z, imgs)

            loss = self.__lossfn(z.view(x.shape))
            # The loss is one scalar for the whole input, so a candidate is
            # accepted or rejected for all images at once.
            if loss < best_loss:
                best_loss = loss
                result = z
                M[:, A_idx] = B_set
                U_M[:, B_idx] = A_set

        return result.view(x.shape), best_loss

    def __sampleDelta(self, B, C, k, dtype=torch.float):
        '''
        Sample k-pixel perturbations for B images. Each pixel is assigned a
        random corner of the C-dimensional cube
        [img_range[0], img_range[1]]^C.
        Returns a tensor of shape [B, k, C] with the given dtype.
        '''
        lo, hi = self.img_range[0], self.img_range[1]
        # randint's upper bound is exclusive: (0, 2) draws from {0, 1}
        corners = torch.randint(0, 2, (B, k, C), device=self.device).to(dtype)
        return lo + (hi - lo) * corners

    def __alphaSchedule(self, iteration):
        '''
        Return the fraction of the k perturbed pixels to resample at this
        iteration: the iteration is rescaled to a 10000-query budget and
        alpha_init is divided by a step-wise increasing factor.
        '''
        iteration = int(iteration / self.n_queries * 10000)
        factors = [1, 2, 4, 5, 6, 8, 10, 12, 15, 20]
        alpha_schedule = [10, 50, 200, 500, 1000, 2000, 4000, 6000, 8000]
        idx = bisect.bisect_left(alpha_schedule, iteration)
        return self.alpha_init / factors[idx]

    def __lossfn(self, x):
        '''
        Scalar objective that the search minimises: the summed model output
        when targeted, its negation otherwise.
        '''
        loss = self.model(x).sum()
        return loss if self.targeted else -loss