# fossil_app/explanations.py
import xplique
import tensorflow as tf
from xplique.attributions import (Saliency, GradientInput, IntegratedGradients, SmoothGrad, VarGrad,
                                  SquareGrad, GradCAM, Occlusion, Rise, GuidedBackprop,
                                  GradCAMPP, Lime, KernelShap, SobolAttributionMethod, HsicAttributionMethod)
from xplique.attributions.global_sensitivity_analysis import LatinHypercube
import numpy as np
import matplotlib.pyplot as plt
from inference_resnet import inference_resnet_finer, preprocess, _clever_crop
from labels import lookup_140
import cv2
BATCH_SIZE = 1
def preprocess_image(image, output_size=(300, 300)):
    """Pad an image to a square with black borders, then resize it to `output_size`."""
    h, w = image.shape[:2]  # (height, width, channels)
    # pad the shorter side so the image becomes square
    if h > w:
        padding = (h - w) // 2
        image_padded = cv2.copyMakeBorder(image, 0, 0, padding, padding, cv2.BORDER_CONSTANT, value=[0, 0, 0])
    else:
        padding = (w - h) // 2
        image_padded = cv2.copyMakeBorder(image, padding, padding, 0, 0, cv2.BORDER_CONSTANT, value=[0, 0, 0])
    # resize the padded square to the target size
    image_resized = cv2.resize(image_padded, output_size, interpolation=cv2.INTER_AREA)
    return image_resized
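# Illustrative sketch (the file name below is hypothetical): pad a photo to a square and
# shrink it to the default 300x300 size using the helper above.
def _demo_preprocess(path="leaf.jpg", output_size=(300, 300)):
    img = cv2.imread(path)                       # (H, W, 3) uint8, BGR, or None if missing
    if img is None:
        return None
    return preprocess_image(img, output_size)    # (300, 300, 3), zero-padded to square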
def transform(image, original_size, output_size):
    """
    Map an XAI output back to the original image scale, then pad it to a square and
    resize to `output_size`, replaying the geometry used by `preprocess_image`.
    """
    h, w = original_size
    # cv2.resize expects the target size as (width, height)
    image = cv2.resize(image, (w, h), interpolation=cv2.INTER_AREA)
    if h > w:
        padding = (h - w) // 2
        image = cv2.copyMakeBorder(image, 0, 0, padding, padding, cv2.BORDER_CONSTANT, value=[0, 0, 0])
    else:
        padding = (w - h) // 2
        image = cv2.copyMakeBorder(image, padding, padding, 0, 0, cv2.BORDER_CONSTANT, value=[0, 0, 0])
    image = cv2.resize(image, output_size, interpolation=cv2.INTER_AREA)
    return image
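# Sketch of the intended round trip (shapes are hypothetical): an attribution map computed
# on the square model input is mapped back so it lines up with the original, unpadded photo
# before being re-padded and resized for display.
def _demo_transform(original_size=(400, 600), model_size=600):
    phi = np.random.rand(model_size, model_size).astype(np.float32)  # fake attribution map
    overlay = transform(phi, original_size=original_size, output_size=(model_size, model_size))
    return overlay.shape  # (model_size, model_size)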
def show(img, original_size, output_size, p=False, **kwargs):
    """Display an image or attribution map with matplotlib, mapped back to the original geometry."""
    # drop a leading batch dimension if present
    if img.shape[0] == 1:
        img = img[0]
    # squeeze a single channel, or convert BGR to RGB for 3-channel images
    if img.shape[-1] == 1:
        img = img[:, :, 0]
    elif img.shape[-1] == 3:
        img = img[:, :, ::-1]
    # normalize to [0, 1]
    if img.max() > 1 or img.min() < 0:
        img = img - img.min()
        img = img / img.max()
    # optional percentile clipping to suppress outliers
    if p is not False:
        img = np.clip(img, np.percentile(img, p), np.percentile(img, 100 - p))
    img = transform(img, original_size=original_size, output_size=output_size)
    plt.imshow(img, **kwargs)
    plt.axis('off')
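# Quick visual sanity check (illustrative only; the output file name is hypothetical): overlay
# a random "heatmap" on a random image and save the figure, mirroring how `explain` uses `show`.
def _demo_show(path="show_demo.png", size=300):
    img = np.random.rand(size, size, 3).astype(np.float32)
    phi = np.random.rand(size, size).astype(np.float32)
    plt.figure()
    show(img, original_size=(size, size), output_size=(size, size))
    show(phi, original_size=(size, size), output_size=(size, size), p=1, alpha=0.2)
    plt.savefig(path)
    plt.close()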
def explain(model, input_image, h, w, explain_method, nb_samples, size=600, n_classes=171):
    """
    Generate attribution maps for the top-5 predicted classes of a single image.
    :param model: The model to explain (its second output is the classification head).
    :param input_image: The input image.
    :param h: Original image height.
    :param w: Original image width.
    :param explain_method: Attribution method to use ("Sobol", "HSIC", "Rise" or "Saliency").
    :param nb_samples: Number of samples for sampling-based methods (used by Rise).
    :param size: Side length of the square model input.
    :param n_classes: Number of output classes.
    :return: The top-5 class names and the paths of the saved explanation images.
    """
    print('using explain_method:', explain_method)
# we only need the classification part of the model
class_model = tf.keras.Model(model.input, model.output[1])
    explainers = []
    if explain_method == "Sobol":
        explainers.append(SobolAttributionMethod(class_model, grid_size=8, nb_design=32))
    if explain_method == "HSIC":
        explainers.append(HsicAttributionMethod(class_model,
                                                grid_size=7, nb_design=1500,
                                                sampler=LatinHypercube(binary=True)))
    if explain_method == "Rise":
        explainers.append(Rise(class_model, nb_samples=nb_samples, batch_size=BATCH_SIZE,
                               grid_size=15, preservation_probability=0.5))
    if explain_method == "Saliency":
        explainers.append(Saliency(class_model))
    # Other xplique explainers (e.g. IntegratedGradients, SmoothGrad, GradCAM) can be appended
    # to `explainers` in the same way. An alternative preprocessing path based on `_clever_crop`
    # and `preprocess` (see inference_resnet) was used previously instead of the plain resize below.
    # resize to the square model input and scale pixel values to [0, 1]
    X = tf.image.resize(input_image, (size, size))
    X = tf.reshape(X, (size, size, 3)) / 255
predictions = class_model.predict(np.array([X]))
    # indices of the five highest-scoring classes, best first
    top_5_indices = np.argsort(predictions[0])[-5:][::-1]
    classes = [lookup_140[index] for index in top_5_indices]
    X = np.expand_dims(X, 0)
explanations = []
    for e, explainer in enumerate(explainers):
        print(f'{e}/{len(explainers)}')
        for i, Y in enumerate(top_5_indices):
            Y = tf.one_hot([Y], n_classes)
            print(f'{i}/{len(top_5_indices)}')
            phi = np.abs(explainer(X, Y))[0]
            # average channel-wise attributions into a single heatmap
            if len(phi.shape) == 3:
                phi = np.mean(phi, -1)
            # apply Gaussian smoothing to reduce high-frequency noise
            phi_smoothed = cv2.GaussianBlur(phi, (5, 5), sigmaX=1.0, sigmaY=1.0)
            # start a fresh figure so overlays from previous classes do not accumulate
            plt.figure()
            show(X[0], original_size=(h, w), output_size=(size, size))
            show(phi_smoothed, original_size=(h, w), output_size=(size, size), p=1, alpha=0.2)
            plt.savefig(f'phi_{e}{i}.png')
            plt.close()
            explanations.append(f'phi_{e}{i}.png')
    print('Done')
    # a single explanation is returned as a bare path rather than a one-element list
    if len(explanations) == 1:
        explanations = explanations[0]
    return classes, explanations
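# Minimal usage sketch (the image path is hypothetical and model loading is left as a
# placeholder, e.g. via inference_resnet_finer; `explain` only requires a model whose
# second output is the classification head).
if __name__ == "__main__":
    model = None  # replace with the real fossil classifier
    image = cv2.imread("fossil.jpg")
    if model is not None and image is not None:
        h, w = image.shape[:2]
        classes, paths = explain(model, image, h, w,
                                 explain_method="Saliency", nb_samples=1000)
        print(classes, paths)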