QuadAttack / debug_test.py
thomaspaniagua
QuadAttack release
71f183c
raw
history blame
5.88 kB
import cvxpy as cp
from cvxpylayers.torch import CvxpyLayer
from torch.nn import functional as F
import torch
from modelguidedattacks import cls_models
import time
# Debug setup: seed the RNG, load a saved classifier head and a saved attack
# case from disk, and declare the cvxpy symbols used by the (commented-out)
# reference solves further down in this script.
torch.manual_seed(0)
device = "cuda"

# model = cls_models.get_model("imagenet", "resnet18", device)

# Placeholder inputs; both names are overwritten by the saved attack case
# loaded from "attack_case.p" below.
rand_feats = torch.randn(1, 512, device=device)
attack_targets = [4, 7, 5, 9, 2]

# # pred_logits = model.head(rand_feats)
# # head_W, head_bias = model.head_matrices()

# NOTE(review): torch.load unpickles its input -- only point these at trusted,
# locally produced files.
# Both loads use map_location=device so the head and the attack case end up on
# the same device; without it the matmul below can fail with a CPU/CUDA
# mismatch depending on where the head was saved from.
(head_W, head_bias, pred_logits) = torch.load(
    "debugsaveimagenet.save", map_location=device
)
rand_feats, rand_logits, attack_targets = torch.load(
    "attack_case.p", map_location=device
)

# Logits recomputed from the loaded features and head (a debugging value; it
# is not referenced again in the visible part of the script).
reconstructed_logits = rand_feats @ head_W.T + head_bias

# head_W is indexed as [num_classes, num_feats].
num_feats = head_W.shape[1]
num_classes = head_W.shape[0]

# cvxpy symbols: x is the feature vector being solved for, anchor_feats the
# point it should stay close to, and (A, b) the linear head, so `logits` is
# the affine expression A @ x + b.
x = cp.Variable(num_feats)
anchor_feats = cp.Parameter(x.shape)
A = cp.Parameter(head_W.shape)
b = cp.Parameter(head_bias.shape)
logits = A @ x + b

# Minimum gap enforced between consecutive logits in the target ordering.
MARGIN = 0.1
# constraints = []
# for i in range(len(attack_targets) - 1):
# constraints.append( logits[attack_targets[i]] - logits[attack_targets[i+1]] >= MARGIN)
# for i in range(num_classes):
# if i in attack_targets:
# continue
# constraints.append(logits[attack_targets[-1]] - logits[i] >= MARGIN )
# objective = cp.Minimize(0.5 * cp.pnorm(x - anchor_feats, p=2))
# problem = cp.Problem(objective, constraints)
# anchor_feats.value = rand_feats[0].cpu().numpy()
# A.value = head_W.detach().cpu().numpy()
# b.value = head_bias.detach().cpu().numpy()
# start_time = time.time()
# problem.solve()
# print ("Non vectorized sol", time.time() - start_time)
# logits_sol_torch = torch.from_numpy(logits.value)
# logits_check = logits_sol_torch.argsort(descending=True)
# feats_sol = torch.from_numpy(x.value[:, None]).float().to(rand_feats)
# sol_feat_norm = (feats_sol[:, 0].cpu() - rand_feats[0].cpu()).norm(dim=-1)
# sol_logits = head_W@feats_sol + head_bias[:, None]
# sol_sort = sol_logits.argsort(dim=0, descending=True)
# Constraint matrix (loop-built reference version).
#
# Each constraint is a +1/-1 pair over the classes, so that (after the final
# transpose) row c of D applied to the logits gives the difference that must
# stay >= MARGIN:
#   * the first len(targets) - 1 constraints chain consecutive attack targets
#     (target[i] above target[i + 1]);
#   * the remaining ones put the last attack target above every class that is
#     not an attack target.
num_constraints = num_classes - 1

# attack_targets may be a tensor here (it is loaded from disk above). Iterating
# a tensor yields tensor objects that hash by identity, so set(attack_targets)
# would not remove anything from the class set. Work on plain Python ints
# instead, without rebinding attack_targets itself (later code still needs the
# original object).
target_ids = torch.as_tensor(attack_targets)
if target_ids.dim() > 1:
    # assumes every batch row holds the same targets -- TODO confirm
    target_ids = target_ids[0]
target_ids = [int(t) for t in target_ids.tolist()]

# Classes outside the attack set, in ascending (deterministic) order.
target_set = set(target_ids)
non_attack_targets = [c for c in range(num_classes) if c not in target_set]

num_chain = len(target_ids) - 1
lowest_target = target_ids[-1]

D = torch.zeros(num_classes, num_constraints)
for constraint_cursor in range(num_constraints):
    if constraint_cursor < num_chain:
        # Ordering inside the attack set: target[i] - target[i + 1].
        D[target_ids[constraint_cursor], constraint_cursor] = 1
        D[target_ids[constraint_cursor + 1], constraint_cursor] = -1
    else:
        # Last attack target minus one non-attack class.
        non_attack_i = constraint_cursor - num_chain
        D[lowest_target, constraint_cursor] = 1
        D[non_attack_targets[non_attack_i], constraint_cursor] = -1

# Rows index constraints, columns index classes: [num_constraints, num_classes].
D = D.T
# vectorized_differences = D @ logits
# vectorized_constraint = vectorized_differences >= torch.full(vectorized_differences.shape, fill_value=MARGIN).numpy()
# Q = 2*torch.eye(x.shape[0]).numpy()
# P = -2*anchor_feats
# G = D@A
# H = MARGIN - D @ b
# G = -G
# H = -H
# vectorized_constraint = G@x <= H
# objective = cp.Minimize((1/2)*cp.quad_form(x, Q) + P.T@x)
# problem = cp.Problem(objective, [vectorized_constraint])
# anchor_feats.value = rand_feats[0].cpu().numpy()
# A.value = head_W.detach().cpu().numpy()
# b.value = head_bias.detach().cpu().numpy()
# start_time = time.time()
# problem.solve()
# print ("vectorized sol", time.time() - start_time)
# logits_sol_torch = torch.from_numpy(logits.value)
# logits_check = logits_sol_torch.argsort(descending=True)
# feats_sol = torch.from_numpy(x.value[:, None]).float().to(rand_feats)
# sol_feat_norm = (feats_sol[:, 0].cpu() - rand_feats[0].cpu()).norm(dim=-1)
# sol_logits = head_W@feats_sol + head_bias[:, None]
# sol_sort = sol_logits.argsort(dim=0, descending=True)
# Batched, loop-free construction of the ordering-constraint matrix D with
# shape [B, num_classes, num_classes]. Row i of D[b] encodes the constraint in
# which class i is the *smaller* logit: +1 on the class that must be larger,
# -1 on class i itself.
import qpth
# Batch size used for this debug run.
B = 2
nz = num_feats
# Not referenced again in the visible part of the script.
nineq = num_constraints
# Same value as the earlier module-level assignment.
device = "cuda"
# Broadcast the attack targets across the batch: [B, K].
attack_targets = attack_targets.expand(B, -1)
K = attack_targets.shape[-1]
# Start from "every class must be below the last (lowest-ranked) attack
# target": -1 on the diagonal, then +1 in the last target's column of every row.
D = -torch.eye(num_classes, device=device)[None].repeat(B, 1, 1)
# Index of the last attack target, repeated for every row: [B, num_classes, 1].
attack_targets_write = attack_targets[:, -1][:, None, None].expand(-1, D.shape[1], -1)
D.scatter_(dim=2, index=attack_targets_write, src=torch.ones(attack_targets_write.shape, device=device))
# Clear out the constraint row for each item in the attack targets; those rows
# are rewritten below with the in-set ordering constraints.
attack_targets_clear = attack_targets[:, :, None].expand(-1, -1, D.shape[-1])
D.scatter_(dim=1, index=attack_targets_clear, src=torch.zeros(attack_targets_clear.shape, device=device))
# Batch index paired with each of the K - 1 consecutive target pairs: [B, K-1].
batch_inds = torch.arange(B, device=device)[:, None].expand(-1, K - 1)
# For each consecutive pair, `pos` is the higher-ranked target and `neg` the
# one that must sit below it.
attack_targets_pos = attack_targets[:, :-1] # [B, K-1]
attack_targets_neg = attack_targets[:, 1:] # [B, K-1]
# (batch, row, col) triples addressing the diagonal entry of each `neg` row.
attack_targets_neg_inds = torch.stack((
batch_inds,
attack_targets_neg,
attack_targets_neg
), dim=0) # [3, B, K - 1]
# Flattened to [3, B * (K - 1)]; the indexed assignment below addresses the
# same entries either way.
attack_targets_neg_inds = attack_targets_neg_inds.view(3, -1)
# -1 on the smaller logit of each pair.
D[attack_targets_neg_inds[0], attack_targets_neg_inds[1], attack_targets_neg_inds[2]] = -1
# Same rows (indexed by `neg`), but the column of the higher-ranked target.
attack_targets_pos_inds = torch.stack((
batch_inds,
attack_targets_neg,
attack_targets_pos
), dim=0) # [3, B, K - 1]
# +1 on the larger logit of each pair.
D[attack_targets_pos_inds[0], attack_targets_pos_inds[1], attack_targets_pos_inds[2]] = 1
# Assemble the batched QP and solve it with qpth.
#
# The ordering constraints are D @ (A z + b) >= MARGIN, rewritten in the
# solver's "G z <= h" form as (-D A) z <= -(MARGIN - D b).

# Classifier head and constraint matrix on the solver device.
A = head_W.detach().to(device)
b = head_bias.detach().to(device)
D = D.to(device)

# Objective 1/2 z^T Q z + p^T z with Q = 2I and p = -2 * rand_feats, i.e. the
# squared distance to rand_feats up to a constant.
# rand_feats is expected to broadcast to [B, num_features] -- TODO confirm
# against the saved attack case.
# The multiplication happens after expand() so Q and P are materialised
# tensors rather than stride-0 views.
identity = torch.eye(nz, device=device)
Q = 2 * identity.unsqueeze(0).expand(B, -1, -1)
anchor = rand_feats.to(device)
P = -2 * anchor.expand(B, -1)

# Inequality constraints: one row per class in each batch element.
G = torch.matmul(-D, A)
H = -(MARGIN - torch.matmul(D, b))

# Constraints are indexed by the smaller logit. The first attack target is not
# smaller than any other logit, so its row of D is all zeros; the row is kept
# for easier parallelisation and its right-hand side is zeroed so it reads
# 0 <= 0 instead of 0 <= -MARGIN.
zero_inds = attack_targets[:, 0:1]  # [B, 1]
H.scatter_(dim=1, index=zero_inds, value=0.0)

# Empty tensor passed for both equality-constraint arguments (no equalities).
e = torch.empty(0, device=device)

# Previously saved QP inputs; loaded here but not passed to the solver below.
Q_t, P_t, G_t, H_t = torch.load("qpinputs.p", map_location=device)

qp_solver = qpth.qp.QPFunction(verbose=True, check_Q_spd=False)
# Solver output transposed so that each column is one batch element.
z_sol = qp_solver(Q, P, G, H, e, e).T

# Logits of the solved features, one column per batch element.
logits = torch.matmul(A, z_sol) + b.unsqueeze(1)

# No-op assignment; presumably a convenient line for a debugger breakpoint.
x = 5