import cvxpy as cp
from cvxpylayers.torch import CvxpyLayer
from torch.nn import functional as F
import torch
from modelguidedattacks import cls_models
import time

torch.manual_seed(0)
device = "cuda"
# model = cls_models.get_model("imagenet", "resnet18", device)

# Synthetic placeholders; both are overwritten by the debug loads below.
rand_feats = torch.randn(1, 512, device=device)
attack_targets = [4, 7, 5, 9, 2]

# pred_logits = model.head(rand_feats)
# head_W, head_bias = model.head_matrices()
head_W, head_bias, pred_logits = torch.load("debugsaveimagenet.save")
rand_feats, rand_logits, attack_targets = torch.load("attack_case.p", map_location=device)
reconstructed_logits = rand_feats @ head_W.T + head_bias

num_feats = head_W.shape[1]
num_classes = head_W.shape[0]

x = cp.Variable(num_feats)
anchor_feats = cp.Parameter(x.shape)
A = cp.Parameter(head_W.shape)
b = cp.Parameter(head_bias.shape)
logits = A @ x + b

MARGIN = 0.1
# constraints = []
# for i in range(len(attack_targets) - 1):
#     constraints.append(logits[attack_targets[i]] - logits[attack_targets[i + 1]] >= MARGIN)
# for i in range(num_classes):
#     if i in attack_targets:
#         continue
#     constraints.append(logits[attack_targets[-1]] - logits[i] >= MARGIN)
# objective = cp.Minimize(0.5 * cp.pnorm(x - anchor_feats, p=2))
# problem = cp.Problem(objective, constraints)
# anchor_feats.value = rand_feats[0].cpu().numpy()
# A.value = head_W.detach().cpu().numpy()
# b.value = head_bias.detach().cpu().numpy()
# start_time = time.time()
# problem.solve()
# print("Non-vectorized sol", time.time() - start_time)
# logits_sol_torch = torch.from_numpy(logits.value)
# logits_check = logits_sol_torch.argsort(descending=True)
# feats_sol = torch.from_numpy(x.value[:, None]).float().to(rand_feats)
# sol_feat_norm = (feats_sol[:, 0].cpu() - rand_feats[0].cpu()).norm(dim=-1)
# sol_logits = head_W @ feats_sol + head_bias[:, None]
# sol_sort = sol_logits.argsort(dim=0, descending=True)
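# A minimal success check for the solve above, kept commented like the block itself:
# the head of the descending argsort should match the requested target ordering.
# assert logits_check[:len(attack_targets)].tolist() == attack_targets.tolist()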
# Constraint matrix D: one row per ordering constraint, so that D @ logits
# stacks every "larger minus smaller" logit difference the attack requires
# to be at least MARGIN.
num_constraints = num_classes - 1
D = torch.zeros(num_classes, num_constraints)
# attack_targets is a tensor after the load above; convert to ints for set ops.
non_attack_targets = list(set(range(num_classes)) - set(attack_targets.tolist()))
for constraint_cursor in range(num_constraints):
    if constraint_cursor < len(attack_targets) - 1:
        # Consecutive attack targets: logits[target_i] >= logits[target_{i+1}] + MARGIN
        D[attack_targets[constraint_cursor], constraint_cursor] = 1
        D[attack_targets[constraint_cursor + 1], constraint_cursor] = -1
    else:
        # Every non-target class sits below the last attack target.
        non_attack_i = constraint_cursor - len(attack_targets) + 1
        D[attack_targets[-1], constraint_cursor] = 1
        D[non_attack_targets[non_attack_i], constraint_cursor] = -1
D = D.T  # [num_constraints, num_classes]
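# A minimal sanity check, assuming reconstructed_logits is [1, num_classes]:
# each row of D picks out one required "larger minus smaller" logit difference,
# so on the clean features at least one entry of D @ logits should fall below MARGIN.
with torch.no_grad():
    clean_margins = D.to(device) @ reconstructed_logits[0]  # [num_constraints]
    print("min required margin on clean logits:", clean_margins.min().item())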
# vectorized_differences = D @ logits
# vectorized_constraint = vectorized_differences >= torch.full(vectorized_differences.shape, fill_value=MARGIN).numpy()

# QP form: minimize (1/2) x^T Q x + P^T x  subject to  G x <= H,
# obtained by rewriting D(Ax + b) >= MARGIN as -DA x <= Db - MARGIN.
# Q = 2*torch.eye(x.shape[0]).numpy()
# P = -2*anchor_feats
# G = D@A
# H = MARGIN - D @ b
# G = -G
# H = -H
# vectorized_constraint = G@x <= H
# objective = cp.Minimize((1/2)*cp.quad_form(x, Q) + P.T@x)
# problem = cp.Problem(objective, [vectorized_constraint])
# anchor_feats.value = rand_feats[0].cpu().numpy()
# A.value = head_W.detach().cpu().numpy()
# b.value = head_bias.detach().cpu().numpy()
# start_time = time.time()
# problem.solve()
# print("vectorized sol", time.time() - start_time)
# logits_sol_torch = torch.from_numpy(logits.value)
# logits_check = logits_sol_torch.argsort(descending=True)
# feats_sol = torch.from_numpy(x.value[:, None]).float().to(rand_feats)
# sol_feat_norm = (feats_sol[:, 0].cpu() - rand_feats[0].cpu()).norm(dim=-1)
# sol_logits = head_W @ feats_sol + head_bias[:, None]
# sol_sort = sol_logits.argsort(dim=0, descending=True)
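# A minimal sketch of a differentiable variant: the CvxpyLayer import at the top is
# otherwise unused here. Assuming the parametric problem above is rebuilt in DPP
# form, it could be wrapped so gradients flow back to the anchor features:
# layer = CvxpyLayer(problem, parameters=[anchor_feats, A, b], variables=[x])
# (feats_sol,) = layer(rand_feats[0], head_W.detach(), head_bias.detach())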
import qpth

# Batched qpth formulation.
B = 2
nz = num_feats
nineq = num_classes  # the batched D below keeps one row per class (the first target's row is trivialized)
attack_targets = attack_targets.expand(B, -1)
K = attack_targets.shape[-1]
# Start by constraining every class logit to sit below the last (lowest-ranked)
# attack target: row i encodes logits[last_target] - logits[i] >= MARGIN.
D = -torch.eye(num_classes, device=device)[None].repeat(B, 1, 1)
attack_targets_write = attack_targets[:, -1][:, None, None].expand(-1, D.shape[1], -1)
D.scatter_(dim=2, index=attack_targets_write, src=torch.ones(attack_targets_write.shape, device=device))
# Clear out the constraint row for each attack target; the rows for targets
# after the first are rewritten with chain constraints below.
attack_targets_clear = attack_targets[:, :, None].expand(-1, -1, D.shape[-1])
D.scatter_(dim=1, index=attack_targets_clear, src=torch.zeros(attack_targets_clear.shape, device=device))
batch_inds = torch.arange(B, device=device)[:, None].expand(-1, K - 1)
attack_targets_pos = attack_targets[:, :-1]  # [B, K-1] higher-ranked target of each pair
attack_targets_neg = attack_targets[:, 1:]   # [B, K-1] lower-ranked target of each pair
attack_targets_neg_inds = torch.stack((
    batch_inds,
    attack_targets_neg,
    attack_targets_neg
), dim=0)  # [3, B, K-1]
attack_targets_neg_inds = attack_targets_neg_inds.view(3, -1)
D[attack_targets_neg_inds[0], attack_targets_neg_inds[1], attack_targets_neg_inds[2]] = -1
attack_targets_pos_inds = torch.stack((
    batch_inds,
    attack_targets_neg,
    attack_targets_pos
), dim=0)  # [3, B, K-1]
D[attack_targets_pos_inds[0], attack_targets_pos_inds[1], attack_targets_pos_inds[2]] = 1
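# A quick consistency check on the batched constraint matrix: every row of D is
# now either a pairwise difference (one +1, one -1) or a cleared row, so each
# row should sum to zero.
assert torch.allclose(D.sum(dim=2), torch.zeros(B, num_classes, device=device))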
A = head_W.detach().to(device)
b = head_bias.detach().to(device)
D = D.to(device)

# rand_feats: [B, num_features]
# Objective: minimize ||x - rand_feats||^2 = x^T x - 2 rand_feats^T x + const,
# i.e. Q = 2I and P = -2 * rand_feats in qpth's (1/2) z^T Q z + P^T z form.
Q = 2 * torch.eye(nz, device=device)[None].expand(B, -1, -1)
P = -2 * rand_feats.to(device).expand(B, -1)
# Ordering constraints D(Ax + b) >= MARGIN, rewritten as Gx <= H:
G = -D @ A  # [B, num_classes, nz]
H = -(MARGIN - D @ b)  # [B, num_classes]
# Constraints are indexed by their smaller logit. The first attack target is not
# smaller than any other logit, so its constraint row is redundant; we keep it
# for easier parallelization but make it trivially satisfiable by zeroing H there.
zero_inds = attack_targets[:, 0:1]  # [B, 1]
H.scatter_(dim=1, index=zero_inds, src=torch.zeros(zero_inds.shape, device=device))
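# Sanity check: the cleared first-target row of D yields an all-zero row in G, so
# with H zeroed there the constraint reduces to the trivially feasible 0 <= 0.
assert G[torch.arange(B, device=device), zero_inds[:, 0]].abs().max().item() == 0.0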
e = torch.empty(0, device=device)  # empty equality constraints for qpth ("none")
Q_t, P_t, G_t, H_t = torch.load("qpinputs.p", map_location=device)  # saved debug inputs (unused below)
z_sol = qpth.qp.QPFunction(verbose=True, check_Q_spd=False)(Q, P, G, H, e, e).T  # [nz, B]
logits = A @ z_sol + b[:, None]  # [num_classes, B]
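# Verify the attack on the recovered logits: if the QP was feasible and solved to
# tolerance, the attack targets should occupy the top-K positions in order.
sol_rank = logits.argsort(dim=0, descending=True)  # [num_classes, B]
print("solved top-K:", sol_rank[:K].T)
print("requested:   ", attack_targets)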
x = 5  # breakpoint anchor for inspecting results in a debugger