xmutly's picture
Upload 294 files
e1aaaac verified
raw
history blame
4.88 kB
import tensorflow as tf
import numpy as np
import torch
class ModelAdapter():
    """Expose a TensorFlow (v1, graph-mode) classifier through a torch-tensor interface.

    The constructor adds to the graph that owns ``logits`` the tensors needed by
    attack code: per-class logit gradients, cross-entropy, DLR and targeted-DLR
    losses, and a logit-difference objective, each with its gradient w.r.t. the
    input placeholder. Every public method takes torch tensors, moves axis 1 of
    the input to position 3 before feeding the graph (presumably NCHW -> NHWC;
    TODO confirm against the wrapped model), runs ``sess``, and returns torch
    tensors on ``self.device`` with input gradients moved back from axis 3 to 1.

    Args:
        logits: TF tensor of model outputs, indexed as ``logits[:, class]``.
        x: TF placeholder for the model input.
        y: TF placeholder for integer labels.
        sess: TF session used to evaluate the graph.
        num_classes: number of output classes.
        device: torch device for every returned tensor. The default ``'cuda'``
            reproduces the previous hard-coded ``.cuda()`` behaviour; pass
            ``'cpu'`` to use the adapter without a GPU.
    """

    def __init__(self, logits, x, y, sess, num_classes=10, device='cuda'):
        self.logits = logits
        self.sess = sess
        self.x_input = x
        self.y_input = y
        self.num_classes = num_classes
        self.device = device

        # Gradient of each class logit w.r.t. the input: one tf.gradients call per
        # class, so it is only built for num_classes <= 10. Otherwise grads is
        # None and grad_logits() raises instead of hitting a missing attribute.
        if num_classes <= 10:
            self.grads = [tf.gradients(self.logits[:, cl], self.x_input)[0]
                          for cl in range(num_classes)]
        else:
            self.grads = None

        # Cross-entropy loss and its input gradient.
        self.xent = tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=self.logits, labels=self.y_input)
        self.grad_xent = tf.gradients(self.xent, self.x_input)[0]

        # Untargeted DLR loss and its input gradient.
        self.dlr = dlr_loss(self.logits, self.y_input, num_classes=self.num_classes)
        self.grad_dlr = tf.gradients(self.dlr, self.x_input)[0]

        # Targeted DLR loss; target labels are fed through self.y_target.
        self.y_target = tf.placeholder(tf.int64, shape=[None])
        self.dlr_target = dlr_loss_targeted(
            self.logits, self.y_input, self.y_target, num_classes=self.num_classes)
        self.grad_target = tf.gradients(self.dlr_target, self.x_input)[0]

        # Logit difference z[la_target] - z[la] and its input gradient. It uses
        # its own label placeholders, independent of y_input.
        self.la = tf.placeholder(tf.int64, shape=[None])
        self.la_target = tf.placeholder(tf.int64, shape=[None])
        la_mask = tf.one_hot(self.la, self.num_classes)
        la_target_mask = tf.one_hot(self.la_target, self.num_classes)
        la_logit = tf.reduce_sum(la_mask * self.logits, axis=1)
        la_target_logit = tf.reduce_sum(la_target_mask * self.logits, axis=1)
        self.diff_logits = la_target_logit - la_logit
        self.grad_diff_logits = tf.gradients(self.diff_logits, self.x_input)[0]

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _to_numpy(t):
        """Return a torch tensor as a NumPy array (detached, on CPU)."""
        return t.detach().cpu().numpy()

    @staticmethod
    def _to_tf_layout(x):
        """Convert a 4-D torch input to NumPy with axis 1 moved to position 3."""
        return np.moveaxis(x.detach().cpu().numpy(), 1, 3)

    @staticmethod
    def _from_tf_layout(arr):
        """Move axis 3 of a 4-D NumPy array back to position 1."""
        return np.moveaxis(arr, 3, 1)

    def _to_torch(self, arr):
        """Wrap a NumPy array as a torch tensor on ``self.device``."""
        return torch.from_numpy(arr).to(self.device)

    def _logits_loss_grad(self, loss, grad, x, y, extra_feed=None):
        """Evaluate logits, ``loss`` and ``grad`` for (x, y), returned as torch tensors.

        ``extra_feed`` supplies values for any further placeholders ``loss``
        depends on. The gradient is converted back with ``_from_tf_layout``.
        """
        feed = {self.x_input: self._to_tf_layout(x), self.y_input: self._to_numpy(y)}
        if extra_feed:
            feed.update(extra_feed)
        logits_val, loss_val, grad_val = self.sess.run([self.logits, loss, grad], feed)
        return (self._to_torch(logits_val),
                self._to_torch(loss_val),
                self._to_torch(self._from_tf_layout(grad_val)))

    # --------------------------------------------------------------- public API

    def predict(self, x):
        """Return the model logits for torch input ``x``."""
        logits_val = self.sess.run(self.logits, {self.x_input: self._to_tf_layout(x)})
        return self._to_torch(logits_val)

    def grad_logits(self, x):
        """Return (logits, per-class logit gradients) for torch input ``x``.

        Raises:
            AttributeError: if the per-class gradients were not built
                (``num_classes > 10``).
        """
        if self.grads is None:
            raise AttributeError(
                'per-class logit gradients were not built (num_classes=%d > 10)'
                % self.num_classes)
        logits_val, grads_val = self.sess.run(
            [self.logits, self.grads], {self.x_input: self._to_tf_layout(x)})
        # One 4-D array per class: stack the class index after the batch axis,
        # then move the last axis (the one fed at position 3) to position 2.
        grads_val = np.moveaxis(np.stack(grads_val, axis=1), -1, 2)
        return self._to_torch(logits_val), self._to_torch(grads_val)

    def get_grad_diff_logits_target(self, x, y=None, y_target=None):
        """Return (z[y_target] - z[y], its gradient w.r.t. x) as torch tensors.

        ``y`` and ``y_target`` are required. The ``None`` defaults are kept
        only for signature compatibility.

        Raises:
            ValueError: if ``y`` or ``y_target`` is None.
        """
        if y is None or y_target is None:
            raise ValueError('get_grad_diff_logits_target requires both y and y_target')
        feed = {self.x_input: self._to_tf_layout(x),
                self.la: self._to_numpy(y),
                self.la_target: self._to_numpy(y_target)}
        diff_val, grad_val = self.sess.run([self.diff_logits, self.grad_diff_logits], feed)
        return self._to_torch(diff_val), self._to_torch(self._from_tf_layout(grad_val))

    def get_logits_loss_grad_xent(self, x, y):
        """Return (logits, per-example cross-entropy, its input gradient)."""
        return self._logits_loss_grad(self.xent, self.grad_xent, x, y)

    def get_logits_loss_grad_dlr(self, x, y):
        """Return (logits, per-example DLR loss, its input gradient)."""
        return self._logits_loss_grad(self.dlr, self.grad_dlr, x, y)

    def get_logits_loss_grad_target(self, x, y, y_target):
        """Return (logits, per-example targeted DLR loss, its input gradient)."""
        return self._logits_loss_grad(
            self.dlr_target, self.grad_target, x, y,
            extra_feed={self.y_target: self._to_numpy(y_target)})
def dlr_loss(x, y, num_classes=10):
    """Difference-of-Logits-Ratio (DLR) loss, one value per example.

        loss = -(z_y - max_{i != y} z_i) / (z_(1) - z_(3) + 1e-12)

    where z are the logits ``x``, z_y is the logit of label ``y`` and z_(k) is
    the k-th largest logit of the row. The numerator is the true-class margin,
    so the loss is negative while ``y`` has the unique largest logit and
    positive once another class overtakes it. The previous version always
    used z_(1) - z_(2) and gave the wrong sign for misclassified examples.

    Args:
        x: 2-D logits tensor (batch, classes) with at least 3 classes.
        y: integer label tensor, one label per row of ``x``.
        num_classes: depth of the one-hot encoding of ``y``.

    Returns:
        1-D tensor with one loss value per row of ``x``.
    """
    top3 = tf.nn.top_k(x, k=3).values  # three largest logits, descending
    z_y = tf.reduce_sum(x * tf.one_hot(y, num_classes, dtype=x.dtype), axis=1)
    # Largest logit among the other classes: the runner-up when y is the
    # arg-max, otherwise the arg-max itself. Ties give the same value either
    # way, so tf.argmax's tie-breaking does not matter.
    y_is_top = tf.cast(tf.equal(tf.argmax(x, axis=1), tf.cast(y, tf.int64)), x.dtype)
    z_other = y_is_top * top3[:, 1] + (1. - y_is_top) * top3[:, 0]
    return -(z_y - z_other) / (top3[:, 0] - top3[:, 2] + 1e-12)
def dlr_loss_targeted(x, y, y_target, num_classes=10):
    """Targeted Difference-of-Logits-Ratio loss, one value per example.

        loss = -(z_y - z_t) / (z_(1) - 0.5 * z_(3) - 0.5 * z_(4) + 1e-12)

    where z are the logits ``x``, z_y / z_t are the logits of the label ``y``
    and the target ``y_target``, and z_(k) is the k-th largest logit.

    Args:
        x: 2-D logits tensor (batch, classes) with at least 4 classes.
        y: integer label tensor, one label per row of ``x``.
        y_target: integer target-label tensor, one per row of ``x``.
        num_classes: depth of the one-hot encodings.

    Returns:
        1-D tensor with one loss value per row of ``x``.
    """
    ranked = tf.contrib.framework.sort(x, axis=1)  # ascending: largest is [:, -1]
    label_logit = tf.reduce_sum(x * tf.one_hot(y, num_classes), axis=1)
    target_logit = tf.reduce_sum(x * tf.one_hot(y_target, num_classes), axis=1)
    margin = label_logit - target_logit
    scale = ranked[:, -1] - .5 * ranked[:, -3] - .5 * ranked[:, -4] + 1e-12
    return -margin / scale