Spaces:
Runtime error
Runtime error
File size: 4,054 Bytes
f740d84 d57d2f2 f740d84 9c45667 f740d84 11bce97 f740d84 11bce97 f740d84 d57d2f2 f740d84 9c45667 f740d84 9c45667 f740d84 9c45667 f740d84 9c45667 f740d84 9c45667 dbacfc1 f740d84 9c45667 8057eaa f740d84 9c45667 f740d84 9730c7f 9c45667 9730c7f f740d84 df2ec0f f740d84 7b1981e f740d84 b037fad f740d84 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 |
import tensorflow as tf
from tensorflow.keras.preprocessing import image
from tensorflow.keras.applications.mobilenet_v2 import MobileNetV2 as keras_model
from tensorflow.keras.applications.mobilenet_v2 import preprocess_input, decode_predictions
import matplotlib.pyplot as plt
from alibi.explainers import IntegratedGradients
from alibi.datasets import load_cats
from alibi.utils.visualization import visualize_image_attr
import numpy as np
from PIL import Image, ImageFilter
import io
import time
import os
import copy
import pickle
import datetime
import urllib.request
import gradio as gr
# Example images used as Gradio examples: local path -> Wikimedia Commons URL.
EXAMPLE_IMAGES = {
    "./cat.jpg": "https://upload.wikimedia.org/wikipedia/commons/3/38/Adorable-animal-cat-20787.jpg",
    "./dog.jpg": "https://upload.wikimedia.org/wikipedia/commons/4/43/Cute_dog.jpg",
}

# Wikimedia's User-Agent policy asks clients to identify themselves and may
# block generic library User-Agents (e.g. the urllib default), so send an
# explicit one.
_DOWNLOAD_HEADERS = {
    "User-Agent": "xai-integrated-gradients-demo/1.0 "
                  "(https://github.com/mawady/colab-recipes-cv)",
}


def _download_if_missing(url, dest, timeout=30):
    """Download `url` to `dest` unless `dest` already exists.

    The response body is written to a temporary ``.part`` file and renamed
    into place only once fully written, so an interrupted download never
    leaves a truncated file that a later start would treat as complete.

    Args:
        url: Remote URL to fetch.
        dest: Local file path to create.
        timeout: Socket timeout in seconds for the request.
    """
    if os.path.exists(dest):
        return
    tmp = dest + ".part"
    request = urllib.request.Request(url, headers=_DOWNLOAD_HEADERS)
    with urllib.request.urlopen(request, timeout=timeout) as response, \
            open(tmp, "wb") as fh:
        fh.write(response.read())
    os.replace(tmp, dest)


# The loop variables reuse the original module-level names `path_input` and
# `url`, so they end up holding the last (dog) entry exactly as before.
for path_input, url in EXAMPLE_IMAGES.items():
    _download_if_missing(url, path_input)
# ImageNet-pretrained classifier whose predictions the app explains.
model = keras_model(weights='imagenet')

# Integrated-gradients settings passed straight to alibi: number of
# interpolation steps between baseline and input, the integral-approximation
# method, and the batch size used internally when evaluating those steps.
n_steps, method, internal_batch_size = 50, "gausslegendre", 50

ig = IntegratedGradients(
    model,
    n_steps=n_steps,
    method=method,
    internal_batch_size=internal_batch_size,
)
def _to_model_input(pil_img):
    """Turn an RGB PIL image into a preprocessed (1, H, W, 3) batch for `model`."""
    batch = np.expand_dims(image.img_to_array(pil_img), axis=0)
    # MobileNetV2's preprocess_input rescales pixel values 0..255 to [-1, 1].
    return preprocess_input(batch)


def _make_baseline_image(img, baseline):
    """Return the integrated-gradients baseline as a pixel-space PIL image.

    Every baseline is built in pixel space (same size as `img`) and then run
    through the same preprocessing as the input. This way the baseline image
    shown to the user is exactly what the explainer integrates from.

    Args:
        img: The (already RGB, correctly sized) input image.
        baseline: 'white', 'black' or 'blur'; any other value selects
            uniform per-pixel noise.
    """
    if baseline == 'white':
        return Image.new('RGB', img.size, (255, 255, 255))
    if baseline == 'black':
        return Image.new('RGB', img.size, (0, 0, 0))
    if baseline == 'blur':
        return img.filter(ImageFilter.GaussianBlur(5))
    width, height = img.size
    noise = np.random.randint(0, 256, size=(height, width, 3), dtype=np.uint8)
    return Image.fromarray(noise)


def do_process(img, baseline):
    """Classify `img` and explain the top prediction with integrated gradients.

    Args:
        img: Input PIL image (the Gradio input is configured for 224x224 RGB).
        baseline: 'white', 'black' or 'blur'; any other value (e.g. 'random')
            uses a uniform-noise baseline.

    Returns:
        Tuple ``(img_res, img_flt, top_preds)``:
        - ``img_res``: the blended attribution heat map rendered as a PIL image.
        - ``img_flt``: the baseline image as a PIL image.
        - ``top_preds``: a dict mapping the top-3 ImageNet class names to their
          scores, rounded to 2 decimals.

    Raises:
        ValueError: if no image was supplied.
    """
    if img is None:
        raise ValueError("No input image was provided.")
    # Normalise to what the network expects. The Gradio `shape=` option
    # normally does this already, in which case these are no-ops.
    if img.mode != 'RGB':
        img = img.convert('RGB')
    height, width = model.input_shape[1:3]
    if height and width and img.size != (width, height):
        img = img.resize((width, height))

    instance = _to_model_input(img)
    preds = model.predict(instance)
    top_preds = {name: round(float(score), 2)
                 for _, name, score in decode_predictions(preds, top=3)[0]}
    predictions = preds.argmax(axis=1)

    img_flt = _make_baseline_image(img, baseline)
    baselines = _to_model_input(img_flt)

    explanation = ig.explain(instance, baselines=baselines, target=predictions)
    attrs = explanation.attributions[0]
    # Pass a NumPy array rather than the PIL object to the visualiser.
    fig, ax = visualize_image_attr(attr=attrs.squeeze(),
                                   original_image=np.asarray(img),
                                   method='blended_heat_map', sign='all',
                                   show_colorbar=True, title=baseline,
                                   plt_fig_axis=None, use_pyplot=False)
    fig.tight_layout()
    buf = io.BytesIO()
    fig.savefig(buf, format='png')
    buf.seek(0)
    img_res = Image.open(buf)
    return img_res, img_flt, top_preds
# NOTE(review): `gr.inputs` / `gr.outputs` and `interpretation=` belong to the
# legacy Gradio API (deprecated in 3.x, removed in 4.0) — confirm the Gradio
# version pinned for this Space before upgrading.
input_im = gr.inputs.Image(shape=(224, 224), image_mode='RGB',
                           invert_colors=False, source="upload",
                           type="pil")
input_drop = gr.inputs.Dropdown(label='Baseline (default: random)',
                                choices=['random', 'black', 'white', 'blur'],
                                default='random', type='value')
output_img = gr.outputs.Image(label='Output of Integrated Gradients', type='pil')
output_base = gr.outputs.Image(label='Baseline image', type='pil')
output_label = gr.outputs.Label(label='Classification results', num_top_classes=3)

title = "XAI - Integrated gradients"
# The classifier is MobileNetV2 (keras_model), not ResNet.
description = "Playground: Integrated gradients for a MobileNetV2 model trained on Imagenet dataset. Tools: Alibi, TF, Gradio."
# Each example is [image path, baseline choice]; the image files are
# downloaded at start-up.
examples = [['./cat.jpg', 'blur'], ['./dog.jpg', 'random']]
article = "<p style='text-align: center'><a href='https://github.com/mawady/colab-recipes-cv' target='_blank'>Colab recipes for computer vision - Dr. Mohamed Elawady</a></p>"

iface = gr.Interface(
    fn=do_process,
    inputs=[input_im, input_drop],
    outputs=[output_img, output_base, output_label],
    live=False,
    interpretation=None,
    title=title,
    description=description,
    article=article,
    examples=examples,
)
iface.launch(debug=True)
|