import tensorflow as tf
import numpy as np
import torch
class ModelAdapter():
def __init__(self, model, num_classes=10):
"""
Please note that model should be tf.keras model without activation function 'softmax'
"""
self.num_classes = num_classes
self.tf_model = model
self.data_format = self.__check_channel_ordering()
def __tf_to_pt(self, tf_tensor):
""" Private function
Convert tf tensor to pt format
Args:
tf_tensor: (tf_tensor) TF tensor
        Returns:
            pt_tensor: (pt_tensor) PyTorch tensor (on GPU)
"""
cpu_tensor = tf_tensor.numpy()
pt_tensor = torch.from_numpy(cpu_tensor).cuda()
return pt_tensor
def set_data_format(self, data_format):
"""
Set data_format manually
Args:
data_format: A string, whose value should be either 'channels_last' or 'channels_first'
"""
        if data_format not in ('channels_last', 'channels_first'):
            raise ValueError("data_format should be either 'channels_last' or 'channels_first'")
self.data_format = data_format
def __check_channel_ordering(self):
""" Private function
        Determine the TF model's channel ordering based on the model's information.
Default ordering is 'channels_last' in TF.
However, 'channels_first' is used in Pytorch.
Returns:
data_format: A string, whose value should be either 'channels_last' or 'channels_first'
"""
data_format = None
# Get the ordering of the dimensions in data from TF model
for L in self.tf_model.layers:
if isinstance(L, tf.keras.layers.Conv2D):
print("[INFO] set data_format = '{:s}'".format(L.data_format))
data_format = L.data_format
break
        # Guess the ordering of the dimensions from the input shape, which should be a 4-D tensor
        if data_format is None:
            print("[WARNING] Cannot find a Conv2D layer")
            input_shape = self.tf_model.input_shape
            # Assume the input is a *color image* with dimensions [batch_size, img_h, img_w, 3]
            if input_shape[3] == 3:
                print("[INFO] Detected input_shape[3] == 3, set data_format = 'channels_last'")
                data_format = 'channels_last'
            # Assume the input is a *grayscale image* with dimensions [batch_size, img_h, img_w, 1]
            elif input_shape[3] == 1:
                print("[INFO] Detected input_shape[3] == 1, set data_format = 'channels_last'")
                data_format = 'channels_last'
            # Assume the input is a *color image* with dimensions [batch_size, 3, img_h, img_w]
            elif input_shape[1] == 3:
                print("[INFO] Detected input_shape[1] == 3, set data_format = 'channels_first'")
                data_format = 'channels_first'
            # Assume the input is a *grayscale image* with dimensions [batch_size, 1, img_h, img_w]
            elif input_shape[1] == 1:
                print("[INFO] Detected input_shape[1] == 1, set data_format = 'channels_first'")
                data_format = 'channels_first'
            else:
                print("[ERROR] Unknown case: unable to infer data_format from input_shape")
return data_format
    # Common functions which may be called inside tf.function #
def __get_logits(self, x_input):
""" Private function
Get model's pre-softmax output in inference mode
Args:
x_input: (tf_tensor) Input data
Returns:
logits: (tf_tensor) Logits
"""
return self.tf_model(x_input, training=False)
def __get_xent(self, logits, y_input):
""" Private function
Get cross entropy loss
Args:
logits: (tf_tensor) Logits.
y_input: (tf_tensor) Label.
Returns:
xent: (tf_tensor) Cross entropy
"""
return tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=y_input)
def __get_dlr(self, logit, y_input):
""" Private function
Get DLR loss
Args:
logit: (tf_tensor) Logits
y_input: (tf_tensor) Input label
Returns:
loss: (tf_tensor) DLR loss
"""
# logit
logit_sort = tf.sort(logit, axis=1)
        # one-hot label
        y_onehot = tf.one_hot(y_input, self.num_classes, dtype=tf.float32)
logit_y = tf.reduce_sum(y_onehot * logit, axis=1)
# z_i
logit_pred = tf.reduce_max(logit, axis=1)
cond = (logit_pred == logit_y)
z_i = tf.where(cond, logit_sort[:, -2], logit_sort[:, -1])
# loss
z_y = logit_y
z_p1 = logit_sort[:, -1]
z_p3 = logit_sort[:, -3]
loss = - (z_y - z_i) / (z_p1 - z_p3 + 1e-12)
return loss
def __get_dlr_target(self, logits, y_input, y_target):
""" Private function
Get targeted version of DLR loss
Args:
            logits: (tf_tensor) Logits
y_input: (tf_tensor) Input label
y_target: (tf_tensor) Input targeted label
Returns:
loss: (tf_tensor) Targeted DLR loss
"""
x = logits
x_sort = tf.sort(x, axis=1)
y_onehot = tf.one_hot(y_input, self.num_classes)
y_target_onehot = tf.one_hot(y_target, self.num_classes)
loss = -(tf.reduce_sum(x * y_onehot, axis=1) - tf.reduce_sum(x * y_target_onehot, axis=1)) / (x_sort[:, -1] - .5 * x_sort[:, -3] - .5 * x_sort[:, -4] + 1e-12)
return loss
    # Functions called directly by the public API #
@tf.function
@tf.autograph.experimental.do_not_convert
def __get_jacobian(self, x_input):
""" Private function
        Get the Jacobian of the logits w.r.t. the input
        Args:
            x_input: (tf_tensor) Input data
        Returns:
            logits: (tf_tensor) Logits
            jacobian: (tf_tensor) Jacobian
"""
with tf.GradientTape(watch_accessed_variables=False) as g:
g.watch(x_input)
logits = self.__get_logits(x_input)
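        # batch_jacobian computes one Jacobian per example, giving a tensor of
        # shape [batch, num_classes, *input_shape] (e.g. [B, K, H, W, C] for channels_last)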
jacobian = g.batch_jacobian(logits, x_input)
return logits, jacobian
@tf.function
@tf.autograph.experimental.do_not_convert
def __get_grad_xent(self, x_input, y_input):
""" Private function
Get gradient of cross entropy
Args:
x_input: (tf_tensor) Input data
y_input: (tf_tensor) Input label
Returns:
logits: (tf_tensor) Logits
xent: (tf_tensor) Cross entropy
grad_xent: (tf_tensor) Gradient of cross entropy
"""
with tf.GradientTape(watch_accessed_variables=False) as g:
g.watch(x_input)
logits = self.__get_logits(x_input)
xent = self.__get_xent(logits, y_input)
grad_xent = g.gradient(xent, x_input)
return logits, xent, grad_xent
@tf.function
@tf.autograph.experimental.do_not_convert
def __get_grad_diff_logits_target(self, x, la, la_target):
""" Private function
        Get the difference of logits and its corresponding gradient
        Args:
            x: (tf_tensor) Input data
            la: (tf_tensor) Input label
            la_target: (tf_tensor) Input targeted label
Returns:
difflogits: (tf_tensor) Difference of logits
grad_diff: (tf_tensor) Gradient of difference of logits
"""
la_mask = tf.one_hot(la, self.num_classes)
la_target_mask = tf.one_hot(la_target, self.num_classes)
with tf.GradientTape(watch_accessed_variables=False) as g:
g.watch(x)
logits = self.__get_logits(x)
difflogits = tf.reduce_sum((la_target_mask - la_mask) * logits, axis=1)
grad_diff = g.gradient(difflogits, x)
return difflogits, grad_diff
@tf.function
@tf.autograph.experimental.do_not_convert
def __get_grad_dlr(self, x_input, y_input):
""" Private function
Get gradient of DLR loss
Args:
x_input: (tf_tensor) Input data
y_input: (tf_tensor) Input label
Returns:
logits: (tf_tensor) Logits
val_dlr: (tf_tensor) DLR loss
grad_dlr: (tf_tensor) Gradient of DLR loss
"""
with tf.GradientTape(watch_accessed_variables=False) as g:
g.watch(x_input)
logits = self.__get_logits(x_input)
val_dlr = self.__get_dlr(logits, y_input)
grad_dlr = g.gradient(val_dlr, x_input)
return logits, val_dlr, grad_dlr
@tf.function
@tf.autograph.experimental.do_not_convert
def __get_grad_dlr_target(self, x_input, y_input, y_target):
""" Private function
Get gradient of targeted DLR loss
Args:
x_input: (tf_tensor) Input data
y_input: (tf_tensor) Input label
y_target: (tf_tensor) Input targeted label
Returns:
logits: (tf_tensor) Logits
val_dlr: (tf_tensor) Targeted DLR loss
grad_dlr: (tf_tensor) Gradient of targeted DLR loss
"""
with tf.GradientTape(watch_accessed_variables=False) as g:
g.watch(x_input)
logits = self.__get_logits(x_input)
dlr_target = self.__get_dlr_target(logits, y_input, y_target)
grad_target = g.gradient(dlr_target, x_input)
return logits, dlr_target, grad_target
# Public API #
def predict(self, x):
"""
Get model's pre-softmax output in inference mode
Args:
            x: (pytorch_tensor) Input data
Returns:
y: (pytorch_tensor) Pre-softmax output
"""
# Convert pt_tensor to tf format
x2 = tf.convert_to_tensor(x.cpu().numpy(), dtype=tf.float32)
if self.data_format == 'channels_last':
x2 = tf.transpose(x2, perm=[0,2,3,1])
# Get result
y = self.__get_logits(x2)
# Convert result to pt format
y = self.__tf_to_pt(y)
return y
def grad_logits(self, x):
"""
        Get logits and the Jacobian of the logits w.r.t. the input
Args:
x: (pytorch_tensor) Input data
Returns:
logits: (pytorch_tensor) Logits
g2: (pytorch_tensor) Jacobian
"""
# Convert pt_tensor to tf format
x2 = tf.convert_to_tensor(x.cpu().numpy(), dtype=tf.float32)
if self.data_format == 'channels_last':
x2 = tf.transpose(x2, perm=[0,2,3,1])
# Get result
logits, g2 = self.__get_jacobian(x2)
# Convert result to pt format
if self.data_format == 'channels_last':
g2 = tf.transpose(g2, perm=[0,1,4,2,3])
logits = self.__tf_to_pt(logits)
g2 = self.__tf_to_pt(g2)
return logits, g2
def get_logits_loss_grad_xent(self, x, y):
"""
Get gradient of cross entropy
Args:
x: (pytorch_tensor) Input data
y: (pytorch_tensor) Input label
Returns:
logits_val: (pytorch_tensor) Logits
loss_indiv_val: (pytorch_tensor) Cross entropy
grad_val: (pytorch_tensor) Gradient of cross entropy
"""
# Convert pt_tensor to tf format
x2 = tf.convert_to_tensor(x.cpu().numpy(), dtype=tf.float32)
y2 = tf.convert_to_tensor(y.cpu().numpy(), dtype=tf.int32)
if self.data_format == 'channels_last':
x2 = tf.transpose(x2, perm=[0,2,3,1])
# Get result
logits_val, loss_indiv_val, grad_val = self.__get_grad_xent(x2, y2)
# Convert result to pt format
if self.data_format == 'channels_last':
grad_val = tf.transpose(grad_val, perm=[0,3,1,2])
logits_val = self.__tf_to_pt(logits_val)
loss_indiv_val = self.__tf_to_pt(loss_indiv_val)
grad_val = self.__tf_to_pt(grad_val)
return logits_val, loss_indiv_val, grad_val
    def set_target_class(self, y, y_target):
        """
        Interface placeholder: targeted labels are passed explicitly to the
        targeted methods below, so there is nothing to store here.
        """
        pass
def get_grad_diff_logits_target(self, x, y, y_target):
"""
        Get the difference of logits and its corresponding gradient
Args:
x: (pytorch_tensor) Input data
y: (pytorch_tensor) Input label
y_target: (pytorch_tensor) Input targeted label
Returns:
difflogits: (pytorch_tensor) Difference of logits
g2: (pytorch_tensor) Gradient of difference of logits
"""
# Convert pt_tensor to tf format
la = tf.convert_to_tensor(y.cpu().numpy(), dtype=tf.int32)
la_target = tf.convert_to_tensor(y_target.cpu().numpy(), dtype=tf.int32)
x2 = tf.convert_to_tensor(x.cpu().numpy(), dtype=tf.float32)
if self.data_format == 'channels_last':
x2 = tf.transpose(x2, perm=[0,2,3,1])
# Get result
difflogits, g2 = self.__get_grad_diff_logits_target(x2, la, la_target)
# Convert result to pt format
if self.data_format == 'channels_last':
g2 = tf.transpose(g2, perm=[0, 3, 1, 2])
difflogits = self.__tf_to_pt(difflogits)
g2 = self.__tf_to_pt(g2)
return difflogits, g2
def get_logits_loss_grad_dlr(self, x, y):
"""
Get gradient of DLR loss
Args:
x: (pytorch_tensor) Input data
y: (pytorch_tensor) Input label
Returns:
logits_val: (pytorch_tensor) Logits
loss_indiv_val: (pytorch_tensor) DLR loss
grad_val: (pytorch_tensor) Gradient of DLR loss
"""
# Convert pt_tensor to tf format
x2 = tf.convert_to_tensor(x.cpu().numpy(), dtype=tf.float32)
y2 = tf.convert_to_tensor(y.cpu().numpy(), dtype=tf.int32)
if self.data_format == 'channels_last':
x2 = tf.transpose(x2, perm=[0,2,3,1])
# Get result
logits_val, loss_indiv_val, grad_val = self.__get_grad_dlr(x2, y2)
# Convert result to pt format
if self.data_format == 'channels_last':
grad_val = tf.transpose(grad_val, perm=[0,3,1,2])
logits_val = self.__tf_to_pt(logits_val)
loss_indiv_val = self.__tf_to_pt(loss_indiv_val)
grad_val = self.__tf_to_pt(grad_val)
return logits_val, loss_indiv_val, grad_val
def get_logits_loss_grad_target(self, x, y, y_target):
"""
Get gradient of targeted DLR loss
Args:
x: (pytorch_tensor) Input data
y: (pytorch_tensor) Input label
y_target: (pytorch_tensor) Input targeted label
Returns:
logits_val: (pytorch_tensor) Logits
loss_indiv_val: (pytorch_tensor) Targeted DLR loss
grad_val: (pytorch_tensor) Gradient of targeted DLR loss
"""
# Convert pt_tensor to tf format
x2 = tf.convert_to_tensor(x.cpu().numpy(), dtype=tf.float32)
y2 = tf.convert_to_tensor(y.cpu().numpy(), dtype=tf.int32)
y_targ = tf.convert_to_tensor(y_target.cpu().numpy(), dtype=tf.int32)
if self.data_format == 'channels_last':
x2 = tf.transpose(x2, perm=[0,2,3,1])
# Get result
logits_val, loss_indiv_val, grad_val = self.__get_grad_dlr_target(x2, y2, y_targ)
# Convert result to pt format
if self.data_format == 'channels_last':
grad_val = tf.transpose(grad_val, perm=[0,3,1,2])
logits_val = self.__tf_to_pt(logits_val)
loss_indiv_val = self.__tf_to_pt(loss_indiv_val)
grad_val = self.__tf_to_pt(grad_val)
return logits_val, loss_indiv_val, grad_val
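
# --------------------------------------------------------------------------- #
# Minimal usage sketch (illustrative only, not part of the adapter itself). It
# assumes a CUDA-capable device, since __tf_to_pt moves results to the GPU,
# and builds a small hypothetical Keras model that outputs raw logits.
# --------------------------------------------------------------------------- #
if __name__ == '__main__':
    # A tiny logits-only (no softmax) Keras model for CIFAR-10-sized inputs
    demo_model = tf.keras.Sequential([
        tf.keras.Input(shape=(32, 32, 3)),
        tf.keras.layers.Conv2D(8, 3, padding='same', activation='relu'),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(10),  # raw logits, no activation
    ])
    adapter = ModelAdapter(demo_model, num_classes=10)

    # The adapter consumes PyTorch tensors in NCHW layout and returns PyTorch tensors
    x = torch.rand(4, 3, 32, 32)
    y = torch.randint(0, 10, (4,))
    logits = adapter.predict(x)                                    # shape (4, 10)
    logits, xent, grad = adapter.get_logits_loss_grad_xent(x, y)   # cross-entropy and its gradient
    print(logits.shape, xent.shape, grad.shape)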