import tensorflow as tf
import numpy as np
import torch


class ModelAdapter:

    def __init__(self, model, num_classes=10):
        """
        Wrap a tf.keras model so it can be driven from PyTorch.

        Note: `model` must be a tf.keras model whose final layer has no
        'softmax' activation, i.e. it must output raw logits.
        """
        self.num_classes = num_classes
        self.tf_model = model
        self.data_format = self.__check_channel_ordering()

    def __tf_to_pt(self, tf_tensor):
        """ Private function
        Convert a TF tensor to a PyTorch tensor.

        Args:
            tf_tensor: (tf_tensor) TF tensor

        Returns:
            pt_tensor: (pytorch_tensor) PyTorch tensor
        """

        cpu_tensor = tf_tensor.numpy()
        # Move to GPU when one is available; fall back to CPU otherwise.
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        pt_tensor = torch.from_numpy(cpu_tensor).to(device)

        return pt_tensor

    def set_data_format(self, data_format):
        """
        Set data_format manually.

        Args:
            data_format: A string, either 'channels_last' or 'channels_first'
        """

        # Membership test: the original `!=` checks chained with `or` were always true.
        if data_format not in ('channels_last', 'channels_first'):
            raise ValueError("data_format should be either 'channels_last' or 'channels_first'")

        self.data_format = data_format

    def __check_channel_ordering(self):
        """ Private function
        Determine the TF model's channel ordering from the model itself.
        The default ordering in TF is 'channels_last', whereas PyTorch
        uses 'channels_first'.

        Returns:
            data_format: A string, either 'channels_last' or 'channels_first'
        """

        data_format = None

        # Prefer the data_format recorded on a Conv2D layer, if there is one.
        for L in self.tf_model.layers:
            if isinstance(L, tf.keras.layers.Conv2D):
                print("[INFO] set data_format = '{:s}'".format(L.data_format))
                data_format = L.data_format
                break

        # Otherwise fall back to guessing from the model's input shape.
        if data_format is None:
            print("[WARNING] Cannot find a Conv2D layer")
            input_shape = self.tf_model.input_shape

            if input_shape[3] == 3:
                print("[INFO] input_shape[3] == 3, so set data_format = 'channels_last'")
                data_format = 'channels_last'

            elif input_shape[3] == 1:
                print("[INFO] input_shape[3] == 1, so set data_format = 'channels_last'")
                data_format = 'channels_last'

            elif input_shape[1] == 3:
                print("[INFO] input_shape[1] == 3, so set data_format = 'channels_first'")
                data_format = 'channels_first'

            elif input_shape[1] == 1:
                print("[INFO] input_shape[1] == 1, so set data_format = 'channels_first'")
                data_format = 'channels_first'

            else:
                print("[ERROR] Unknown case")

        return data_format

    def __get_logits(self, x_input):
        """ Private function
        Get the model's pre-softmax output in inference mode.

        Args:
            x_input: (tf_tensor) Input data

        Returns:
            logits: (tf_tensor) Logits
        """

        return self.tf_model(x_input, training=False)

    def __get_xent(self, logits, y_input):
        """ Private function
        Get the cross-entropy loss.

        Args:
            logits: (tf_tensor) Logits
            y_input: (tf_tensor) Label

        Returns:
            xent: (tf_tensor) Cross-entropy loss
        """

        return tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=y_input)

    def __get_dlr(self, logit, y_input):
        """ Private function
        Get the DLR (Difference of Logits Ratio) loss.

        Args:
            logit: (tf_tensor) Logits
            y_input: (tf_tensor) Input label

        Returns:
            loss: (tf_tensor) DLR loss
        """

        # Sort logits in ascending order, so [:, -1] is the largest.
        logit_sort = tf.sort(logit, axis=1)

        # z_y: the logit of the true class.
        y_onehot = tf.one_hot(y_input, self.num_classes, dtype=tf.float32)
        logit_y = tf.reduce_sum(y_onehot * logit, axis=1)

        # z_i: the largest logit excluding the true class.
        logit_pred = tf.reduce_max(logit, axis=1)
        cond = (logit_pred == logit_y)
        z_i = tf.where(cond, logit_sort[:, -2], logit_sort[:, -1])

        z_y = logit_y
        z_p1 = logit_sort[:, -1]
        z_p3 = logit_sort[:, -3]

        loss = -(z_y - z_i) / (z_p1 - z_p3 + 1e-12)
        return loss

    def __get_dlr_target(self, logits, y_input, y_target):
        """ Private function
        Get the targeted version of the DLR loss.

        Args:
            logits: (tf_tensor) Logits
            y_input: (tf_tensor) Input label
            y_target: (tf_tensor) Target label

        Returns:
            loss: (tf_tensor) Targeted DLR loss
        """

        x = logits
        x_sort = tf.sort(x, axis=1)
        y_onehot = tf.one_hot(y_input, self.num_classes)
        y_target_onehot = tf.one_hot(y_target, self.num_classes)
        loss = (-(tf.reduce_sum(x * y_onehot, axis=1) - tf.reduce_sum(x * y_target_onehot, axis=1))
                / (x_sort[:, -1] - .5 * x_sort[:, -3] - .5 * x_sort[:, -4] + 1e-12))

        return loss

    @tf.function
    @tf.autograph.experimental.do_not_convert
    def __get_jacobian(self, x_input):
        """ Private function
        Get the Jacobian of the logits with respect to the input.

        Args:
            x_input: (tf_tensor) Input data

        Returns:
            logits: (tf_tensor) Logits
            jacobian: (tf_tensor) Jacobian
        """

        with tf.GradientTape(watch_accessed_variables=False) as g:
            g.watch(x_input)
            logits = self.__get_logits(x_input)

        jacobian = g.batch_jacobian(logits, x_input)

        return logits, jacobian

    @tf.function
    @tf.autograph.experimental.do_not_convert
    def __get_grad_xent(self, x_input, y_input):
        """ Private function
        Get the gradient of the cross-entropy loss.

        Args:
            x_input: (tf_tensor) Input data
            y_input: (tf_tensor) Input label

        Returns:
            logits: (tf_tensor) Logits
            xent: (tf_tensor) Cross-entropy loss
            grad_xent: (tf_tensor) Gradient of the cross-entropy loss
        """

        with tf.GradientTape(watch_accessed_variables=False) as g:
            g.watch(x_input)
            logits = self.__get_logits(x_input)
            xent = self.__get_xent(logits, y_input)

        grad_xent = g.gradient(xent, x_input)

        return logits, xent, grad_xent

    @tf.function
    @tf.autograph.experimental.do_not_convert
    def __get_grad_diff_logits_target(self, x, la, la_target):
        """ Private function
        Get the difference of logits and its gradient.

        Args:
            x: (tf_tensor) Input data
            la: (tf_tensor) Input label
            la_target: (tf_tensor) Target label

        Returns:
            difflogits: (tf_tensor) Difference of logits
            grad_diff: (tf_tensor) Gradient of the difference of logits
        """

        la_mask = tf.one_hot(la, self.num_classes)
        la_target_mask = tf.one_hot(la_target, self.num_classes)

        with tf.GradientTape(watch_accessed_variables=False) as g:
            g.watch(x)
            logits = self.__get_logits(x)
            difflogits = tf.reduce_sum((la_target_mask - la_mask) * logits, axis=1)

        grad_diff = g.gradient(difflogits, x)

        return difflogits, grad_diff

    @tf.function
    @tf.autograph.experimental.do_not_convert
    def __get_grad_dlr(self, x_input, y_input):
        """ Private function
        Get the gradient of the DLR loss.

        Args:
            x_input: (tf_tensor) Input data
            y_input: (tf_tensor) Input label

        Returns:
            logits: (tf_tensor) Logits
            val_dlr: (tf_tensor) DLR loss
            grad_dlr: (tf_tensor) Gradient of the DLR loss
        """

        with tf.GradientTape(watch_accessed_variables=False) as g:
            g.watch(x_input)
            logits = self.__get_logits(x_input)
            val_dlr = self.__get_dlr(logits, y_input)

        grad_dlr = g.gradient(val_dlr, x_input)

        return logits, val_dlr, grad_dlr

    @tf.function
    @tf.autograph.experimental.do_not_convert
    def __get_grad_dlr_target(self, x_input, y_input, y_target):
        """ Private function
        Get the gradient of the targeted DLR loss.

        Args:
            x_input: (tf_tensor) Input data
            y_input: (tf_tensor) Input label
            y_target: (tf_tensor) Target label

        Returns:
            logits: (tf_tensor) Logits
            dlr_target: (tf_tensor) Targeted DLR loss
            grad_target: (tf_tensor) Gradient of the targeted DLR loss
        """

        with tf.GradientTape(watch_accessed_variables=False) as g:
            g.watch(x_input)
            logits = self.__get_logits(x_input)
            dlr_target = self.__get_dlr_target(logits, y_input, y_target)

        grad_target = g.gradient(dlr_target, x_input)

        return logits, dlr_target, grad_target

    def predict(self, x):
        """
        Get the model's pre-softmax output in inference mode.

        Args:
            x: (pytorch_tensor) Input data

        Returns:
            y: (pytorch_tensor) Pre-softmax output (logits)
        """

        x2 = tf.convert_to_tensor(x.cpu().numpy(), dtype=tf.float32)
        if self.data_format == 'channels_last':
            x2 = tf.transpose(x2, perm=[0, 2, 3, 1])

        y = self.__get_logits(x2)

        y = self.__tf_to_pt(y)

        return y

    def grad_logits(self, x):
        """
        Get the logits and the Jacobian of the logits.

        Args:
            x: (pytorch_tensor) Input data

        Returns:
            logits: (pytorch_tensor) Logits
            g2: (pytorch_tensor) Jacobian
        """

        x2 = tf.convert_to_tensor(x.cpu().numpy(), dtype=tf.float32)
        if self.data_format == 'channels_last':
            x2 = tf.transpose(x2, perm=[0, 2, 3, 1])

        logits, g2 = self.__get_jacobian(x2)

        # Map the Jacobian back to NCHW: (B, K, H, W, C) -> (B, K, C, H, W).
        if self.data_format == 'channels_last':
            g2 = tf.transpose(g2, perm=[0, 1, 4, 2, 3])
        logits = self.__tf_to_pt(logits)
        g2 = self.__tf_to_pt(g2)

        return logits, g2

    def get_logits_loss_grad_xent(self, x, y):
        """
        Get the gradient of the cross-entropy loss.

        Args:
            x: (pytorch_tensor) Input data
            y: (pytorch_tensor) Input label

        Returns:
            logits_val: (pytorch_tensor) Logits
            loss_indiv_val: (pytorch_tensor) Per-example cross-entropy loss
            grad_val: (pytorch_tensor) Gradient of the cross-entropy loss
        """

        x2 = tf.convert_to_tensor(x.cpu().numpy(), dtype=tf.float32)
        y2 = tf.convert_to_tensor(y.cpu().numpy(), dtype=tf.int32)
        if self.data_format == 'channels_last':
            x2 = tf.transpose(x2, perm=[0, 2, 3, 1])

        logits_val, loss_indiv_val, grad_val = self.__get_grad_xent(x2, y2)

        if self.data_format == 'channels_last':
            grad_val = tf.transpose(grad_val, perm=[0, 3, 1, 2])
        logits_val = self.__tf_to_pt(logits_val)
        loss_indiv_val = self.__tf_to_pt(loss_indiv_val)
        grad_val = self.__tf_to_pt(grad_val)

        return logits_val, loss_indiv_val, grad_val

    def set_target_class(self, y, y_target):
        # No-op: targeted labels are passed explicitly to the methods below;
        # the method is kept (presumably) for interface compatibility.
        pass

    def get_grad_diff_logits_target(self, x, y, y_target):
        """
        Get the difference of logits and its gradient.

        Args:
            x: (pytorch_tensor) Input data
            y: (pytorch_tensor) Input label
            y_target: (pytorch_tensor) Target label

        Returns:
            difflogits: (pytorch_tensor) Difference of logits
            g2: (pytorch_tensor) Gradient of the difference of logits
        """

        la = tf.convert_to_tensor(y.cpu().numpy(), dtype=tf.int32)
        la_target = tf.convert_to_tensor(y_target.cpu().numpy(), dtype=tf.int32)
        x2 = tf.convert_to_tensor(x.cpu().numpy(), dtype=tf.float32)
        if self.data_format == 'channels_last':
            x2 = tf.transpose(x2, perm=[0, 2, 3, 1])

        difflogits, g2 = self.__get_grad_diff_logits_target(x2, la, la_target)

        if self.data_format == 'channels_last':
            g2 = tf.transpose(g2, perm=[0, 3, 1, 2])
        difflogits = self.__tf_to_pt(difflogits)
        g2 = self.__tf_to_pt(g2)

        return difflogits, g2

    def get_logits_loss_grad_dlr(self, x, y):
        """
        Get the gradient of the DLR loss.

        Args:
            x: (pytorch_tensor) Input data
            y: (pytorch_tensor) Input label

        Returns:
            logits_val: (pytorch_tensor) Logits
            loss_indiv_val: (pytorch_tensor) Per-example DLR loss
            grad_val: (pytorch_tensor) Gradient of the DLR loss
        """

        x2 = tf.convert_to_tensor(x.cpu().numpy(), dtype=tf.float32)
        y2 = tf.convert_to_tensor(y.cpu().numpy(), dtype=tf.int32)
        if self.data_format == 'channels_last':
            x2 = tf.transpose(x2, perm=[0, 2, 3, 1])

        logits_val, loss_indiv_val, grad_val = self.__get_grad_dlr(x2, y2)

        if self.data_format == 'channels_last':
            grad_val = tf.transpose(grad_val, perm=[0, 3, 1, 2])
        logits_val = self.__tf_to_pt(logits_val)
        loss_indiv_val = self.__tf_to_pt(loss_indiv_val)
        grad_val = self.__tf_to_pt(grad_val)

        return logits_val, loss_indiv_val, grad_val

    def get_logits_loss_grad_target(self, x, y, y_target):
        """
        Get the gradient of the targeted DLR loss.

        Args:
            x: (pytorch_tensor) Input data
            y: (pytorch_tensor) Input label
            y_target: (pytorch_tensor) Target label

        Returns:
            logits_val: (pytorch_tensor) Logits
            loss_indiv_val: (pytorch_tensor) Per-example targeted DLR loss
            grad_val: (pytorch_tensor) Gradient of the targeted DLR loss
        """

        x2 = tf.convert_to_tensor(x.cpu().numpy(), dtype=tf.float32)
        y2 = tf.convert_to_tensor(y.cpu().numpy(), dtype=tf.int32)
        y_targ = tf.convert_to_tensor(y_target.cpu().numpy(), dtype=tf.int32)
        if self.data_format == 'channels_last':
            x2 = tf.transpose(x2, perm=[0, 2, 3, 1])

        logits_val, loss_indiv_val, grad_val = self.__get_grad_dlr_target(x2, y2, y_targ)

        if self.data_format == 'channels_last':
            grad_val = tf.transpose(grad_val, perm=[0, 3, 1, 2])
        logits_val = self.__tf_to_pt(logits_val)
        loss_indiv_val = self.__tf_to_pt(loss_indiv_val)
        grad_val = self.__tf_to_pt(grad_val)

        return logits_val, loss_indiv_val, grad_val
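

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only). It assumes a tf.keras model that
# outputs raw logits; the tiny CNN below is a hypothetical stand-in, and the
# random batch just demonstrates the expected NCHW PyTorch layout.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    keras_model = tf.keras.Sequential([
        tf.keras.Input(shape=(32, 32, 3)),
        tf.keras.layers.Conv2D(8, 3, activation='relu'),
        tf.keras.layers.GlobalAveragePooling2D(),
        tf.keras.layers.Dense(10),  # no softmax: raw logits
    ])

    adapter = ModelAdapter(keras_model, num_classes=10)

    x = torch.rand(4, 3, 32, 32)   # PyTorch batch in NCHW layout
    y = torch.randint(0, 10, (4,))

    logits = adapter.predict(x)    # -> PyTorch tensor of shape (4, 10)
    logits_val, xent, grad = adapter.get_logits_loss_grad_xent(x, y)
    print(logits.shape, xent.shape, grad.shape)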