KC123hello's picture
Upload Files
fc0ff8f verified
import torch
import torch.nn.functional as F
import collections.abc as container_abcs
# Code taken from https://github.com/chs20/RobustVLM/tree/main
# some parts of this code are adapted from
# https://github.com/M4xim4l/InNOutRobustness/blob/main/utils/adversarial_attacks/utils.py
def project_perturbation(perturbation, eps, norm):
    """Project ``perturbation`` back into the ``eps``-ball of the given norm.

    Args:
        perturbation: Perturbation tensor; for the L2 case the first dimension
            is treated as the sample axis (each slice along dim 0 is projected
            independently).
        eps: Radius of the ball.
        norm: ``'inf'``/``'linf'``/``'Linf'`` clamps every element to
            ``[-eps, eps]``; ``2``/``2.0``/``'l2'``/``'L2'``/``'2'`` rescales each
            slice along dim 0 whose L2 norm exceeds ``eps`` down to ``eps``.

    Returns:
        The projected tensor (a new tensor; the input is not modified).

    Raises:
        NotImplementedError: If ``norm`` is not one of the supported values.
    """
    if norm in ('inf', 'linf', 'Linf'):
        return perturbation.clamp(min=-eps, max=eps)
    if norm in (2, 2.0, 'l2', 'L2', '2'):
        # renorm caps the p-norm of every sub-tensor taken along dim 0,
        # leaving slices already inside the ball untouched.
        return perturbation.renorm(p=2, dim=0, maxnorm=eps)
    raise NotImplementedError(f'Norm {norm} not supported')
def normalize_grad(grad, p):
    """Normalize a batch of gradients according to the norm ``p``.

    Args:
        grad: Gradient tensor whose first dimension indexes samples.
        p: Norm identifier. ``'inf'``/``'linf'``/``'Linf'`` returns the
            elementwise sign of ``grad``; ``2``/``2.0``/``'l2'``/``'L2'``/``'2'``
            rescales each sample (flattened) to unit L2 norm.

    Returns:
        Tensor with the same shape as ``grad``.

    Raises:
        NotImplementedError: If ``p`` is not a supported norm (previously the
            function silently returned ``None`` in that case).
    """
    if p in ['inf', 'linf', 'Linf']:
        return grad.sign()
    elif p in [2, 2.0, 'l2', 'L2', '2']:
        bs = grad.shape[0]
        # reshape (not view) so non-contiguous gradients are accepted.
        grad_flat = grad.reshape(bs, -1)
        # F.normalize divides by max(norm, eps), so all-zero samples stay zero.
        grad_normalized = F.normalize(grad_flat, p=2, dim=1)
        return grad_normalized.view_as(grad)
    else:
        raise NotImplementedError(f'Norm {p} not supported')
def L1_norm(x, keepdim=False):
    """Return the L1 norm of each sample in a batch.

    Args:
        x: Tensor whose first dimension indexes samples; all remaining
            dimensions are flattened before summing.
        keepdim: If True, the result is reshaped to ``(batch, 1, ..., 1)``
            with the same number of dimensions as ``x``.

    Returns:
        Tensor of per-sample sums of absolute values.
    """
    # reshape (not view): x.abs() keeps x's strides, so view would fail on
    # non-contiguous (e.g. permuted) inputs.
    z = x.abs().reshape(x.shape[0], -1).sum(-1)
    if keepdim:
        # Restore singleton dims so z broadcasts against x.
        z = z.view(-1, *[1] * (x.dim() - 1))
    return z
def L2_norm(x, keepdim=False):
    """Return the L2 (Euclidean) norm of each sample in a batch.

    Args:
        x: Tensor whose first dimension indexes samples; all remaining
            dimensions are flattened before reducing.
        keepdim: If True, the result is reshaped to ``(batch, 1, ..., 1)``
            with the same number of dimensions as ``x``.

    Returns:
        Tensor of per-sample square roots of summed squares.
    """
    # reshape (not view): x ** 2 keeps x's strides, so view would fail on
    # non-contiguous (e.g. permuted) inputs.
    z = (x ** 2).reshape(x.shape[0], -1).sum(-1).sqrt()
    if keepdim:
        # Restore singleton dims so z broadcasts against x.
        z = z.view(-1, *[1] * (x.dim() - 1))
    return z
def L0_norm(x):
    """Return the number of non-zero entries in each sample of a batch.

    Args:
        x: Tensor whose first dimension indexes samples; all remaining
            dimensions are flattened before counting.

    Returns:
        Integer tensor of per-sample non-zero counts.
    """
    # reshape (not view): the comparison result keeps x's strides, so view
    # would fail on non-contiguous (e.g. permuted) inputs.
    return (x != 0.).reshape(x.shape[0], -1).sum(-1)
def zero_gradients(x):
    """Zero, in place, the ``.grad`` of a tensor or of every tensor nested in ``x``.

    Tensors without a gradient are skipped. Iterables (lists, tuples, ...)
    are walked recursively; any other object is ignored.

    Args:
        x: A tensor, or an arbitrarily nested iterable of tensors.
    """
    if isinstance(x, torch.Tensor):
        grad = x.grad
        if grad is None:
            return
        # Detach first so zeroing does not record anything in the autograd graph.
        grad.detach_()
        grad.zero_()
        return
    if not isinstance(x, container_abcs.Iterable):
        return
    for item in x:
        zero_gradients(item)