import xplique
import tensorflow as tf
from xplique.attributions import (Saliency, GradientInput, IntegratedGradients, SmoothGrad,
                                  VarGrad, SquareGrad, GradCAM, Occlusion, Rise,
                                  GuidedBackprop, GradCAMPP, Lime, KernelShap,
                                  SobolAttributionMethod, HsicAttributionMethod)
from xplique.attributions.global_sensitivity_analysis import LatinHypercube
import numpy as np
import matplotlib.pyplot as plt
from inference_resnet import inference_resnet_finer, preprocess, _clever_crop
from labels import lookup_140

BATCH_SIZE = 1


def show(img, p=False, **kwargs):
    """Draw an image or attribution map on the current matplotlib axes.

    Args:
        img: Array-like image or map. It is copied to float32. A leading axis of
            size 1 is dropped. A trailing single channel is squeezed to a 2-D map,
            and a trailing 3-channel axis has its channel order reversed.
        p: If not ``False``, values are clipped to the [p, 100 - p] percentile
            range before drawing.
        **kwargs: Forwarded unchanged to ``plt.imshow`` (e.g. ``alpha``).
    """
    img = np.array(img, dtype=np.float32)

    # drop a leading singleton (batch) axis
    if img.shape[0] == 1:
        img = img[0]

    # single channel -> 2-D map; 3 channels -> reversed channel order
    # (presumably BGR -> RGB for display; confirm against `preprocess`)
    if img.shape[-1] == 1:
        img = img[:, :, 0]
    elif img.shape[-1] == 3:
        img = img[:, :, ::-1]

    # rescale to [0, 1] only when values fall outside that range
    if img.max() > 1 or img.min() < 0:
        img -= img.min()
        value_range = img.max()
        # a constant image would otherwise divide 0 by 0 and become all-NaN
        if value_range > 0:
            img /= value_range

    if p is not False:
        img = np.clip(img, np.percentile(img, p), np.percentile(img, 100 - p))

    plt.imshow(img, **kwargs)
    plt.axis('off')


def _build_explainers(class_model, explain_method, nb_samples):
    """Return the xplique explainers selected by ``explain_method``.

    Recognised names are "Sobol", "HSIC", "Rise" and "Saliency"; any other
    value yields an empty list.
    """
    if explain_method == "Sobol":
        return [SobolAttributionMethod(class_model, grid_size=8, nb_design=32)]
    if explain_method == "HSIC":
        return [HsicAttributionMethod(class_model,
                                      grid_size=7, nb_design=1500,
                                      sampler=LatinHypercube(binary=True))]
    if explain_method == "Rise":
        return [Rise(class_model, nb_samples=nb_samples, batch_size=BATCH_SIZE,
                     grid_size=15, preservation_probability=0.5)]
    if explain_method == "Saliency":
        return [Saliency(class_model)]
    return []


def explain(model, input_image, explain_method, nb_samples, size=600, n_classes=171):
    """Compute and save attribution maps for the model's top-5 predicted classes.

    Args:
        model: Keras model whose second output (``model.output[1]``) is used as
            the classification head to explain.
        input_image: Raw input image; it is cropped with ``_clever_crop`` to
            ``(size, size)`` and then passed through ``preprocess``.
        explain_method: One of "Sobol", "HSIC", "Rise" or "Saliency". Any other
            value produces no explanations.
        nb_samples: Number of samples passed to RISE (ignored by other methods).
        size: Side length of the square crop fed to the model.
        n_classes: Depth of the one-hot target vectors handed to the explainer.

    Returns:
        ``(classes, explanations)``. ``classes`` lists the ``lookup_140`` labels of
        the top-5 predictions, highest score first. ``explanations`` is the list of
        saved PNG file names (``phi_<explainer><rank>.png`` in the working
        directory), or the bare file name when exactly one was produced.
    """
    print('using explain_method:', explain_method)

    # we only need the classification part of the model
    class_model = tf.keras.Model(model.input, model.output[1])

    explainers = _build_explainers(class_model, explain_method, nb_samples)
    if not explainers:
        print(f'unknown explain_method: {explain_method!r}; no explanations generated')

    cropped, _ = _clever_crop(input_image, (size, size))
    X = preprocess(cropped, size=size)
    X = np.expand_dims(X, 0)  # add the batch axis once; used for prediction and explanation

    predictions = class_model.predict(X)
    top_5_indices = np.argsort(predictions[0])[-5:][::-1]
    classes = [lookup_140[index] for index in top_5_indices]

    explanations = []
    for e, explainer in enumerate(explainers):
        print(f'{e}/{len(explainers)}')
        for i, class_index in enumerate(top_5_indices):
            Y = tf.one_hot([class_index], n_classes)
            print(f'{i}/{len(top_5_indices)}')
            phi = np.abs(explainer(X, Y))[0]
            # collapse a per-channel attribution to a single 2-D map
            if len(phi.shape) == 3:
                phi = np.mean(phi, -1)

            # Draw on a dedicated figure and close it after saving. Reusing the
            # global pyplot figure would stack every image from every iteration
            # (and every call), leaking memory and re-rendering old images.
            fig = plt.figure()
            show(X[0])
            show(phi, p=1, alpha=0.4)
            filename = f'phi_{e}{i}.png'
            plt.savefig(filename)
            plt.close(fig)
            explanations.append(filename)

    print('Done')
    if len(explanations) == 1:
        explanations = explanations[0]
    return classes, explanations