# Code taken from https://github.com/chs20/RobustVLM/tree/main
# some parts of this code are adapted from
# https://github.com/M4xim4l/InNOutRobustness/blob/main/utils/adversarial_attacks/utils.py
import collections.abc as container_abcs

import torch
import torch.nn.functional as F

def project_perturbation(perturbation, eps, norm):
    """Project a batch of perturbations back onto the eps-ball of the given norm."""
    if norm in ['inf', 'linf', 'Linf']:
        # L-inf ball: clip every coordinate into [-eps, eps]
        return torch.clamp(perturbation, -eps, eps)
    elif norm in [2, 2.0, 'l2', 'L2', '2']:
        # L2 ball: renorm over dim 0 rescales each sample whose L2 norm exceeds eps
        return torch.renorm(perturbation, p=2, dim=0, maxnorm=eps)
    else:
        raise NotImplementedError(f'Norm {norm} not supported')

def normalize_grad(grad, p):
    """Normalize a batch of gradients for a PGD step under the given norm."""
    if p in ['inf', 'linf', 'Linf']:
        # steepest-ascent direction under L-inf is the sign of the gradient
        return grad.sign()
    elif p in [2, 2.0, 'l2', 'L2', '2']:
        # scale each sample's gradient to unit L2 norm
        bs = grad.shape[0]
        grad_flat = grad.view(bs, -1)
        grad_normalized = F.normalize(grad_flat, p=2, dim=1)
        return grad_normalized.view_as(grad)
    else:
        # mirror project_perturbation instead of silently returning None
        raise NotImplementedError(f'Norm {p} not supported')

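# Illustrative sketch (not part of the RobustVLM source): how normalize_grad
# and project_perturbation compose into a single PGD step. `model` and
# `loss_fn` are hypothetical placeholders; `loss_fn` is assumed to return a
# scalar loss to be maximized.
def _example_pgd_step(model, x, delta, loss_fn, step_size, eps, norm='linf'):
    delta = delta.detach().requires_grad_(True)
    loss = loss_fn(model(x + delta))
    grad = torch.autograd.grad(loss, delta)[0]
    # ascend the loss along the norm-appropriate steepest direction ...
    delta = delta + step_size * normalize_grad(grad, norm)
    # ... then project the perturbation back onto the eps-ball
    return project_perturbation(delta.detach(), eps, norm)
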
def L1_norm(x, keepdim=False):
    """Per-sample L1 norm of a batched tensor."""
    z = x.abs().view(x.shape[0], -1).sum(-1)
    if keepdim:
        # reshape to (bs, 1, ..., 1) so the result broadcasts against x
        z = z.view(-1, *[1] * (len(x.shape) - 1))
    return z


def L2_norm(x, keepdim=False):
    """Per-sample L2 norm of a batched tensor."""
    z = (x ** 2).view(x.shape[0], -1).sum(-1).sqrt()
    if keepdim:
        z = z.view(-1, *[1] * (len(x.shape) - 1))
    return z


def L0_norm(x):
    """Per-sample L0 norm (number of non-zero entries)."""
    return (x != 0.).view(x.shape[0], -1).sum(-1)

def zero_gradients(x):
    """Recursively zero the .grad buffers of a tensor or an iterable of tensors.

    Stand-in for torch.autograd.gradcheck.zero_gradients, which was removed
    in newer PyTorch releases.
    """
    if isinstance(x, torch.Tensor):
        if x.grad is not None:
            x.grad.detach_()
            x.grad.zero_()
    elif isinstance(x, container_abcs.Iterable):
        for elem in x:
            zero_gradients(elem)
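

# Illustrative smoke test (not part of the RobustVLM source) exercising the
# helpers above on random data.
if __name__ == '__main__':
    x = torch.randn(4, 3, 8, 8, requires_grad=True)
    (x ** 2).sum().backward()
    zero_gradients(x)
    assert x.grad.abs().sum() == 0  # .grad was zeroed in place

    delta = 5 * torch.randn(4, 3, 8, 8)
    projected = project_perturbation(delta, eps=1.0, norm='l2')
    # every per-sample L2 norm now lies within the eps-ball (radius 1.0)
    print(L2_norm(projected))
    print(L0_norm(projected), L1_norm(projected, keepdim=True).shape)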