"""Gradio playground for Integrated Gradients on an ImageNet MobileNetV2.

Downloads two example images, wraps a pretrained Keras MobileNetV2 in an
alibi ``IntegratedGradients`` explainer and serves a Gradio UI that shows the
attribution heat map, the baseline image used, and the top-3 predictions.
"""
import copy
import datetime
import io
import os
import pickle
import time
import urllib.request

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from alibi.datasets import load_cats
from alibi.explainers import IntegratedGradients
from alibi.utils.visualization import visualize_image_attr
from PIL import Image, ImageFilter
from tensorflow.keras.applications.mobilenet_v2 import MobileNetV2 as keras_model
from tensorflow.keras.applications.mobilenet_v2 import (
    decode_predictions,
    preprocess_input,
)
from tensorflow.keras.preprocessing import image

# (source URL, local path) of the example images offered in the UI.
EXAMPLE_IMAGES = (
    (
        "https://upload.wikimedia.org/wikipedia/commons/3/38/Adorable-animal-cat-20787.jpg",
        "./cat.jpg",
    ),
    (
        "https://upload.wikimedia.org/wikipedia/commons/4/43/Cute_dog.jpg",
        "./dog.jpg",
    ),
)
for url, path_input in EXAMPLE_IMAGES:
    # Avoid re-downloading on every start when the file is already present.
    if not os.path.exists(path_input):
        urllib.request.urlretrieve(url, filename=path_input)

model = keras_model(weights="imagenet")

n_steps = 50
method = "gausslegendre"
internal_batch_size = 50
ig = IntegratedGradients(
    model, n_steps=n_steps, method=method, internal_batch_size=internal_batch_size
)


def _top_predictions(preds, top=3):
    """Map the top-`top` ImageNet class names to scores rounded to 2 decimals."""
    return {
        name: round(float(score), 2)
        for _, name, score in decode_predictions(preds, top=top)[0]
    }


def _make_baseline(img, baseline, shape, dtype):
    """Build the IG baseline both as a displayable image and in model space.

    The baseline is constructed in raw pixel space (0-255, same layout as
    ``image.img_to_array``) and only then passed through ``preprocess_input``,
    so that e.g. "black" really is black for the model and the returned
    image shows exactly what the explainer integrates from.

    Args:
        img: the (PIL) input image; only used by the "blur" baseline.
        baseline: one of "white", "black", "blur"; anything else is random.
        shape: shape of the preprocessed model input (batch axis included).
        dtype: dtype of the preprocessed model input.

    Returns:
        (baseline_image, baselines): a PIL image of the baseline and the
        preprocessed baseline array with the given shape/dtype.
    """
    if baseline == "white":
        pixels = np.full(shape, 255.0, dtype=dtype)
    elif baseline == "black":
        pixels = np.zeros(shape, dtype=dtype)
    elif baseline == "blur":
        blurred = image.img_to_array(img.filter(ImageFilter.GaussianBlur(5)))
        pixels = np.expand_dims(blurred, axis=0).astype(dtype)
    else:
        pixels = (np.random.random_sample(shape) * 255.0).astype(dtype)
    # Render the display image before preprocess_input, which may rescale
    # float arrays in place.
    img_flt = Image.fromarray(np.uint8(np.squeeze(pixels, axis=0)))
    baselines = preprocess_input(pixels)
    return img_flt, baselines


def _render_attribution(attrs, img, title):
    """Render attributions blended over `img` and return the plot as a PIL image."""
    fig, _ = visualize_image_attr(
        attr=attrs.squeeze(),
        original_image=img,
        method="blended_heat_map",
        sign="all",
        show_colorbar=True,
        title=title,
        plt_fig_axis=None,
        use_pyplot=False,
    )
    fig.tight_layout()
    buf = io.BytesIO()
    fig.savefig(buf, format="png")
    buf.seek(0)
    img_res = Image.open(buf)
    img_res.load()  # decode now so the image does not depend on `buf` later
    return img_res


def do_process(img, baseline):
    """Classify `img` and explain the top class with Integrated Gradients.

    Args:
        img: PIL image supplied by the Gradio image input.
        baseline: baseline choice from the dropdown
            ("random", "black", "white" or "blur").

    Returns:
        (attribution_image, baseline_image, predictions) where predictions
        maps the top-3 class names to their rounded scores.

    Raises:
        ValueError: if no image was provided.
    """
    if img is None:
        raise ValueError("Please upload an image.")
    instance = np.expand_dims(image.img_to_array(img), axis=0)
    instance = preprocess_input(instance)

    preds = model.predict(instance)
    dctPreds = _top_predictions(preds)
    predictions = preds.argmax(axis=1)  # explain the top-1 class

    img_flt, baselines = _make_baseline(img, baseline, instance.shape, instance.dtype)
    explanation = ig.explain(instance, baselines=baselines, target=predictions)
    attrs = explanation.attributions[0]

    img_res = _render_attribution(attrs, img, baseline)
    return img_res, img_flt, dctPreds


# NOTE(review): gr.inputs / gr.outputs and `interpretation=` are the legacy
# Gradio API; confirm the pinned gradio version still provides them.
input_im = gr.inputs.Image(
    shape=(224, 224), image_mode="RGB", invert_colors=False, source="upload", type="pil"
)
input_drop = gr.inputs.Dropdown(
    label="Baseline (default: random)",
    choices=["random", "black", "white", "blur"],
    default="random",
    type="value",
)
output_img = gr.outputs.Image(label="Output of Integrated Gradients", type="pil")
output_base = gr.outputs.Image(label="Baseline image", type="pil")
output_label = gr.outputs.Label(label="Classification results", num_top_classes=3)

title = "XAI - Integrated gradients"
description = "Playground: Integrated gradients for a MobileNetV2 model trained on Imagenet dataset. Tools: Alibi, TF, Gradio."
examples = [["./cat.jpg", "blur"], ["./dog.jpg", "random"]]
article = "By Dr. Mohamed Elawady"

iface = gr.Interface(
    fn=do_process,
    inputs=[input_im, input_drop],
    outputs=[output_img, output_base, output_label],
    live=False,
    interpretation=None,
    title=title,
    description=description,
    article=article,
    examples=examples,
)
iface.launch(debug=True)