# fossil_app/explanations.py
# Last change: commit 86104a0 by andy-wyx ("feat: workbench page"); file size 5.07 kB.
import xplique
import tensorflow as tf
from xplique.attributions import (Saliency, GradientInput, IntegratedGradients, SmoothGrad, VarGrad,
SquareGrad, GradCAM, Occlusion, Rise, GuidedBackprop,
GradCAMPP, Lime, KernelShap,SobolAttributionMethod,HsicAttributionMethod)
from xplique.attributions.global_sensitivity_analysis import LatinHypercube
import numpy as np
import matplotlib.pyplot as plt
from inference_resnet import inference_resnet_finer, preprocess, _clever_crop
from labels import lookup_140
# Batch size forwarded to xplique explainers that accept a ``batch_size``
# argument (used when constructing the Rise explainer in this module).
BATCH_SIZE = 1
def show(img, p=False, **kwargs):
    """Draw an image or heatmap on the current matplotlib axes, with axes hidden.

    The array is converted to float32 and prepared for display:

    * a leading axis of length 1 (e.g. a batch of one) is dropped;
    * for 3-D arrays, a single trailing channel is squeezed away (so
      matplotlib applies a colormap) and a 3-channel trailing axis has its
      channel order reversed (presumably BGR -> RGB to undo the model
      preprocessing -- TODO confirm against ``inference_resnet.preprocess``);
    * values outside [0, 1] are min-max rescaled into [0, 1];
    * when ``p`` is given, values are clipped to the [p, 100 - p] percentiles.

    :param img: Array-like image, heatmap, or a batch containing one of those.
    :param p: Percentile used for symmetric clipping, or False to skip clipping.
    :param kwargs: Forwarded unchanged to ``plt.imshow`` (e.g. ``alpha``).
    """
    img = np.array(img, dtype=np.float32)
    # Drop a leading singleton axis (e.g. a batch holding a single image).
    if img.shape[0] == 1:
        img = img[0]
    # Channel handling only applies to (H, W, C) arrays; a 2-D heatmap whose
    # width happens to be 1 or 3 must not be indexed with three subscripts.
    if img.ndim == 3:
        if img.shape[-1] == 1:
            img = img[:, :, 0]
        elif img.shape[-1] == 3:
            img = img[:, :, ::-1]
    # Rescale into [0, 1] when needed. A constant array has a zero range after
    # the shift; dividing by it would turn the whole image into NaN.
    if img.max() > 1 or img.min() < 0:
        img -= img.min()
        value_range = img.max()
        if value_range > 0:
            img /= value_range
    # Optional percentile clipping to suppress outliers in heatmaps.
    if p is not False:
        img = np.clip(img, np.percentile(img, p), np.percentile(img, 100 - p))
    plt.imshow(img, **kwargs)
    plt.axis('off')
def _build_explainers(class_model, explain_method, nb_samples):
    """Return the list of xplique explainers selected by ``explain_method``.

    Recognised names are "Sobol", "HSIC", "Rise" and "Saliency"; any other
    value yields an empty list, so no explanation images are produced.
    ``nb_samples`` is only used by Rise.
    """
    if explain_method == "Sobol":
        return [SobolAttributionMethod(class_model, grid_size=8, nb_design=32)]
    if explain_method == "HSIC":
        return [HsicAttributionMethod(class_model,
                                      grid_size=7, nb_design=1500,
                                      sampler=LatinHypercube(binary=True))]
    if explain_method == "Rise":
        return [Rise(class_model, nb_samples=nb_samples, batch_size=BATCH_SIZE,
                     grid_size=15, preservation_probability=0.5)]
    if explain_method == "Saliency":
        return [Saliency(class_model)]
    print(f'unknown explain_method {explain_method!r}: no explanations generated')
    return []


def _save_overlay(image, phi, path):
    """Save ``image`` with ``phi`` overlaid as a translucent heatmap to ``path``.

    A dedicated figure is created and always closed, so repeated calls do not
    keep stacking images onto one global pyplot figure (which grows memory
    and slows every subsequent save in a long-running app).
    """
    fig = plt.figure()
    try:
        show(image)
        show(phi, p=1, alpha=0.4)
        fig.savefig(path)
    finally:
        plt.close(fig)


def explain(model, input_image, explain_method, nb_samples, size=600, n_classes=171):
    """Explain the model's top-5 predictions for a single image.

    The image is cropped with ``_clever_crop`` and preprocessed with
    ``preprocess`` at ``size`` x ``size``. The classification head is then run
    to find the 5 highest-scoring classes. For each selected explainer and
    each of those classes, an attribution map is computed, overlaid on the
    input, and saved as a PNG in the current working directory.

    :param model: Keras model; its second output (``model.output[1]``) is
        treated as the classification output that gets explained.
    :param input_image: Raw input image, passed as-is to ``_clever_crop``.
    :param explain_method: One of "Sobol", "HSIC", "Rise" or "Saliency".
        Any other value produces no explanation images.
    :param nb_samples: Number of samples for the Rise explainer (ignored
        by the other methods).
    :param size: Side length used for cropping and preprocessing.
    :param n_classes: Length of the one-hot target vectors; presumably must
        match the width of the classification output -- TODO confirm.
    :return: ``(classes, explanations)``. ``classes`` holds the
        ``lookup_140`` labels of the top-5 classes, highest score first.
        ``explanations`` is the list of saved file names ``phi_{e}{i}.png``
        (explainer index ``e``, class rank ``i``), or the bare file name
        string when exactly one image was produced.
    """
    print('using explain_method:', explain_method)
    # Only the classification part of the model is explained.
    class_model = tf.keras.Model(model.input, model.output[1])
    explainers = _build_explainers(class_model, explain_method, nb_samples)

    cropped, _ = _clever_crop(input_image, (size, size))
    # Single-image batch, used both for prediction and for the explainers.
    batch = np.expand_dims(preprocess(cropped, size=size), 0)
    predictions = class_model.predict(batch)
    top_5_indices = np.argsort(predictions[0])[-5:][::-1]
    classes = [lookup_140[index] for index in top_5_indices]

    # NOTE(review): file names are fixed and relative to the CWD, so
    # concurrent calls overwrite each other's images.
    explanations = []
    for e, explainer in enumerate(explainers):
        print(f'{e}/{len(explainers)}')
        for i, class_index in enumerate(top_5_indices):
            target = tf.one_hot([class_index], n_classes)
            print(f'{i}/{len(top_5_indices)}')
            phi = np.abs(explainer(batch, target))[0]
            # Collapse per-channel attributions into a single 2-D heatmap.
            if len(phi.shape) == 3:
                phi = np.mean(phi, -1)
            path = f'phi_{e}{i}.png'
            _save_overlay(batch[0], phi, path)
            explanations.append(path)
    print('Done')
    # Kept for backward compatibility: a lone result is returned unwrapped.
    if len(explanations) == 1:
        explanations = explanations[0]
    return classes, explanations