# Code taken and adapted from https://github.com/wagnermoritz/GSE

import torch
import torchvision
import math
import torch.nn.functional as F

from vlm_eval.attacks.attack import Attack

# required input size: batch_size x num_media x num_frames x channels x height x width
class GSEAttack(Attack):
    def __init__(self, model, *args, mask_out='none', ver=False, img_range=(-1, 1),
                 search_steps=4, targeted=False, sequential=False, search_factor=2,
                 gb_size=5, sgm=1.5, mu=1, sigma=0.0025, iters=200, k_hat=10,
                 q=0.25, **kwargs):
        '''
        Implementation of the GSE attack.
        args:
            model: Callable, PyTorch classifier.
            mask_out: Masks out context images if set to 'context', query images
                if set to 'query', and nothing if set to 'none'.
            ver: Bool, print progress if True.
            img_range: Tuple of ints/floats, lower and upper bound of image
                entries.
            search_steps: Int, number of steps for the section search on the
                trade-off parameter.
            targeted: Bool, given label is used as a target label if True.
            sequential: Bool, perturbations are computed sequentially for all
                images in the batch if True. For fair comparison to the
                Homotopy attack.
            search_factor: Float, factor to increase/decrease the trade-off
                parameter until an upper/lower bound for the search is found.
            gb_size: Odd int, size of the Gaussian blur kernel.
            sgm: Float, sigma of the Gaussian blur kernel.
            mu: Float, trade-off parameter for 2-norm regularization.
            sigma: Float, step size.
            iters: Int, number of iterations.
            k_hat: Int, number of iterations before transitioning to NAG.
            q: Float, inverse of the increase factor for adjust_lambda.
        '''
        super().__init__(model, img_range=img_range, targeted=targeted)
        self.ver = ver
        self.search_steps = search_steps
        self.sequential = sequential
        self.search_factor = search_factor
        self.gb_size = gb_size
        self.sgm = sgm
        self.mu = mu
        self.sigma = sigma
        self.iters = iters
        self.k_hat = k_hat
        self.q = q
        self.mask_out = mask_out if mask_out != 'none' else None
    def adjust_lambda(self, lam, noise):
        '''
        Adjust trade-off parameters (lambda) to update the search space.
        '''
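        # The blurred support of the current noise determines where the search
        # space grows or shrinks: pixels in or near already-perturbed regions
        # keep a blur value > 1, so their lambda shrinks (perturbing there gets
        # cheaper), while untouched pixels hit the q-branch and their lambda
        # grows by a factor of 1/q, discouraging new isolated groups.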
        x = noise.detach().clone().abs().mean(dim=1, keepdim=True).sign()
        gb = torchvision.transforms.GaussianBlur((self.gb_size, self.gb_size),
                                                 sigma=self.sgm)
        x = gb(x) + 1
        x = torch.where(x == 1, self.q, x)
        lam /= x[:, 0, :, :]
        return lam
    def section_search(self, x, steps=50):
        '''
        Section search for finding the maximal lambda such that the
        perturbation is non-zero after the first iteration.
        '''
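        # Strategy: take a single proximal-gradient step from zero noise and
        # check whether the prox zeroes out the whole perturbation. The upper
        # bound is doubled until it does, then `steps` rounds of bisection
        # narrow in on the threshold; the result is scaled by 0.01 so the
        # first real iteration starts with a non-trivial perturbation.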
        # x has shape batch_size x num_media x num_frames x channels x height x width
        B, C, H, W = x.size(1), x.size(3), x.size(4), x.size(5)
        noise = torch.zeros_like(x, requires_grad=True)
        loss = (-self.model(x + noise).sum() + self.mu
                * torch.norm(noise.view(B, C, H, W), p=2, dim=(1, 2, 3)).sum())
        grad = torch.autograd.grad(loss, [noise])[0].detach().view(B, C, H, W)
        noise.detach_()
        ones = torch.ones((B, H, W), device=self.device)
        # define upper and lower bound for the section search
        lb = torch.zeros((B,), dtype=torch.float, device=self.device).view(-1, 1, 1)
        ub = lb.clone() + 0.001
        # double ub until the prox step zeroes out the entire perturbation
        mask = torch.norm(self.prox(grad * self.sigma, ones * ub * self.sigma),
                          p=0, dim=(1, 2, 3)) != 0
        while mask.any():
            ub[mask] *= 2
            mask = torch.norm(self.prox(grad * self.sigma, ones * ub * self.sigma),
                              p=0, dim=(1, 2, 3)) != 0
        # perform the bisection between lb and ub
        for _ in range(steps):
            cur = (ub + lb) / 2
            mask = torch.norm(self.prox(grad * self.sigma, ones * cur * self.sigma),
                              p=0, dim=(1, 2, 3)) == 0
            ub[mask] = cur[mask]
            mask = torch.logical_not(mask)
            lb[mask] = cur[mask]
        cur = (lb + ub).view(-1) / 2
        return 0.01 * cur
    def __call__(self, x, y, *args, **kwargs):
        '''
        Call the attack for a batch of images x, or sequentially for all images
        in x depending on self.sequential.
        args:
            x: Tensor of shape [B, N, M, C, H, W], batch of images
               (batch_size x num_media x num_frames x channels x height x width).
            y: Tensor of shape [B], batch of labels.
        Returns a tensor of the same shape as x containing adversarial examples.
        '''
        if self.sequential:
            result = x.clone()
            for i, (x_, y_) in enumerate(zip(x, y)):
                result[i] = self.perform_att(x_.unsqueeze(0), y_.unsqueeze(0),
                                             mu=self.mu, sigma=self.sigma,
                                             k_hat=self.k_hat).detach()
            return result
        else:
            return self.perform_att(x, y, mu=self.mu, sigma=self.sigma,
                                    k_hat=self.k_hat)
    def _set_mask(self, data):
        # Build a multiplicative gradient mask: entries set to 0 mark images
        # that must not be perturbed.
        mask = torch.ones_like(data)
        if self.mask_out == 'context':
            mask[:, :-1, ...] = 0
        elif self.mask_out == 'query':
            mask[:, -1, ...] = 0
        elif isinstance(self.mask_out, int):
            mask[:, self.mask_out, ...] = 0
        elif self.mask_out is None:
            pass
        else:
            raise NotImplementedError(f'Unknown mask_out: {self.mask_out}')
        return mask
    def perform_att(self, x, y, mu, sigma, k_hat):
        '''
        Perform the GSE attack on a batch of images x with corresponding
        labels y. Note: y is accepted for interface compatibility; the
        adapted loss below does not use it.
        '''
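        # Outer loop: section search over the trade-off parameter lambda.
        # Inner loop: self.iters proximal-gradient steps. The first k_hat
        # steps apply the 1/2-norm prox and re-adjust lambda after each step;
        # afterwards the sparsity pattern is frozen via `lammask` and the
        # remaining steps run Nesterov-accelerated gradient descent (NAG)
        # restricted to that support.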
        x = x.to(self.device)
        # input shape: batch_size x num_media x num_frames x channels x height x width;
        # the views below assume batch_size == 1 and num_frames == 1, so the
        # num_media axis acts as the effective batch dimension B
        B, C, H, W = x.shape[1], x.shape[3], x.shape[4], x.shape[5]
        lams = self.section_search(x)
        mask_out = self._set_mask(x).view(B, C, H, W)
        # save x and lams for resetting them at the beginning of every
        # section search step
        save_x = x.clone()
        save_lams = lams.clone()
        # upper and lower bounds for the section search
        ub_lams = torch.full_like(lams, torch.inf)
        lb_lams = torch.full_like(lams, 0.0)
        # tensor for saving successful adversarial examples in the inner loop
        result = x.clone()
        # tensor for saving the best adversarial example so far
        result2 = x.clone()
        best_l0 = torch.full((B,), torch.inf, device=self.device).type(x.type())
        # section search
        for step in range(self.search_steps):
            x = save_x.clone()
            lams = save_lams.clone()
            lam = torch.ones_like(x.view(B, C, H, W))[:, 0, :, :] * lams.view(-1, 1, 1)
            # tracks the images for which no adv. example has been found yet
            active = torch.ones(B, dtype=torch.bool, device=self.device)
            # set the initial perturbation to zero
            noise = torch.zeros_like(x, requires_grad=True)
            noise_old = noise.clone()
            lr = 1
            # attack
            for j in range(self.iters):
                if self.ver:
                    print(f'\rSearch step {step + 1}/{self.search_steps}, ' +
                          f'Prox.Grad. Iteration {j + 1}/{self.iters}, ' +
                          f'Images left: {int(active.sum())}', end='')
                if not active.any():
                    break
                self.model.model.zero_grad()
                loss = (-self.model(x + noise).sum() + mu
                        * (torch.norm(noise.view(B, C, H, W), p=2, dim=(1, 2, 3)) ** 2).sum())
                noise_grad_data = torch.autograd.grad(loss, [noise])[0].detach().view(B, C, H, W)
                with torch.no_grad():
                    # zero the gradient for masked-out images (mask_out is B x C x H x W)
                    noise_grad_data = noise_grad_data * mask_out
                    # Nesterov/FISTA momentum coefficient
                    lr_ = (1 + math.sqrt(1 + 4 * lr ** 2)) / 2
                    if j == k_hat:
                        # freeze the support: entries whose lambda has grown
                        # beyond its initial value are forced to stay zero
                        lammask = (lam > lams.view(-1, 1, 1))[:, None, :, :]
                        lammask = lammask.repeat(1, C, 1, 1)
                        noise_old = noise.clone()
                    if j < k_hat:
                        # proximal-gradient phase with momentum
                        noise = noise - sigma * noise_grad_data.view(1, B, 1, C, H, W)
                        noise = self.prox(noise.view(B, C, H, W), lam * sigma).view(1, B, 1, C, H, W)
                        noise_tmp = noise.clone()
                        noise = lr / lr_ * noise + (1 - (lr / lr_)) * noise_old
                        noise_old = noise_tmp.clone()
                        lam = self.adjust_lambda(lam, noise.view(B, C, H, W))
                    else:
                        # NAG phase on the frozen support
                        noise = noise - sigma * noise_grad_data.view(1, B, 1, C, H, W)
                        noise_tmp = noise.clone()
                        noise = lr / lr_ * noise + (1 - (lr / lr_)) * noise_old
                        noise_old = noise_tmp.clone()
                        noise[lammask.view(1, B, 1, C, H, W)] = 0
                    # clamp the adv. example to the valid image range
                    x_adv = torch.clamp(x + noise, *self.img_range)
                    noise = x_adv - x
                    lr = lr_
                noise.requires_grad = True
            # section search update:
            # no adv. example found => decrease the upper bound and current lambda;
            # adv. example found => save it if its "0-norm" beats the previous
            # best, then increase the lower bound and current lambda
            for i in range(B):
                if active[i]:
                    ub_lams[i] = save_lams[i]
                    save_lams[i] = 0.95 * lb_lams[i] + 0.05 * save_lams[i]
                else:
                    l0 = self.l20((result[i] - save_x[i]).unsqueeze(0)).to(self.device)
                    if l0 < best_l0[i]:
                        best_l0[i] = l0
                        result2[i] = result[i].clone()
                    if torch.isinf(ub_lams[i]):
                        lb_lams[i] = save_lams[i]
                        save_lams[i] *= self.search_factor
                    else:
                        lb_lams[i] = save_lams[i]
                        save_lams[i] = (ub_lams[i] + save_lams[i]) / 2
        if self.ver:
            print('')
        # Note: in this adaptation `active` is never cleared, so the branch
        # above always lowers lambda and the perturbed batch from the last
        # iteration is returned.
        return x_adv
    def extract_patches(self, x):
        '''
        Extracts and returns all overlapping size-by-size patches from
        the image batch x.
        '''
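        # An identity kernel makes conv2d act like a sliding-window unfold:
        # each output channel copies one pixel of the current 8x8 window, so
        # the output contains every stride-1 patch of the input.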
        B, C, _, _ = x.shape
        size = 8
        kernel = torch.eye(size ** 2).view(size ** 2, 1, size, size)
        kernel = kernel.repeat(C, 1, 1, 1).to(x.device)
        out = F.conv2d(x, kernel, groups=C)
        out = out.view(B, C, size, size, -1)
        out = out.permute(0, 4, 1, 2, 3)
        return out.contiguous()
    def l20(self, x):
        '''
        Computes d_{2,0}(x[i]) for every perturbation x[i] in the batch x,
        as described in section 3.2.
        '''
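        # d_{2,0} counts the 8x8 patches with non-zero l2 norm, i.e. how many
        # local pixel groups a perturbation touches; lower values mean the
        # perturbation is concentrated in fewer regions. Counts are summed
        # over the media and frame axes of each batch element.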
        B, N, M, C, _, _ = x.shape
        l20s = []
        for b in range(B):
            total = 0.0
            for n in range(N):
                for m in range(M):
                    x_ = x[b, n, m]  # single perturbation of shape C x H x W
                    patches = self.extract_patches(x_.unsqueeze(0))
                    l2s = torch.norm(patches, p=2, dim=(2, 3, 4))
                    total += (l2s != 0).float().sum().item()
            l20s.append(total)
        # one value per batch element, so callers can compare it to a scalar
        return torch.tensor(l20s)
    def prox(self, grad_loss_noise, lam):
        '''
        Computes the proximal operator of the 1/2-norm applied to the gradient
        step on the current noise. Operates in place on grad_loss_noise.
        '''
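        # Closed-form prox of lam * ||.||_{1/2}^{1/2} (half thresholding):
        # entries with magnitude at most p_lam = (54^(1/3) / 4) * lam^(2/3)
        # are set to zero; larger entries are shrunk via the cosine expression
        # below.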
        lam = lam[:, None, :, :]
        sh = list(grad_loss_noise.shape)
        lam = lam.expand(*sh)
        p_lam = (54 ** (1 / 3) / 4) * lam ** (2 / 3)
        mask1 = (grad_loss_noise > p_lam)
        mask2 = (torch.abs(grad_loss_noise) <= p_lam)
        mask3 = (grad_loss_noise < -p_lam)
        mask4 = torch.logical_or(mask1, mask3)
        # phi is only valid (and only used) where |entry| > p_lam; elsewhere
        # the arccos argument exceeds 1 and yields NaNs that are discarded
        phi_lam_x = torch.arccos((lam / 8) * (torch.abs(grad_loss_noise) / 3)
                                 ** (-1.5))
        grad_loss_noise[mask4] = ((2 / 3) * torch.abs(grad_loss_noise[mask4])
                                  * (1 + torch.cos((2 * math.pi) / 3
                                  - (2 * phi_lam_x[mask4]) / 3))).to(torch.float32)
        # restore the sign of negative entries and zero the small ones
        grad_loss_noise[mask3] = -grad_loss_noise[mask3]
        grad_loss_noise[mask2] = 0
        return grad_loss_noise
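

# ---------------------------------------------------------------------------
# Minimal smoke-test sketch (illustrative, not part of the adapted attack).
# It assumes the Attack base class accepts (model, img_range=..., targeted=...)
# and exposes self.device, as the code above implies. `DummyVLM` is a
# hypothetical stand-in for the wrapped model: a callable returning one loss
# value per example, with a `.model` attribute because perform_att calls
# self.model.model.zero_grad().
if __name__ == '__main__':
    import torch.nn as nn

    class DummyVLM:
        def __init__(self):
            self.model = nn.Identity()  # only needed for the zero_grad() calls

        def __call__(self, x):
            # mean absolute activation as a stand-in per-example loss
            return x.abs().mean(dim=(1, 2, 3, 4, 5))

    attack = GSEAttack(DummyVLM(), img_range=(-1, 1), iters=20, search_steps=1)
    # batch_size x num_media x num_frames x channels x height x width
    x = torch.rand(1, 2, 1, 3, 32, 32) * 2 - 1
    y = torch.zeros(1, dtype=torch.long)  # unused by the adapted loss
    x_adv = attack(x, y)
    print('output shape:', tuple(x_adv.shape))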