mawady commited on
Commit
f740d84
1 Parent(s): 745c1d9

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +96 -0
app.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tensorflow as tf
2
+ from tensorflow.keras.applications.resnet_v2 import ResNet50V2
3
+ from tensorflow.keras.preprocessing import image
4
+ from tensorflow.keras.applications.resnet_v2 import preprocess_input, decode_predictions
5
+ import matplotlib.pyplot as plt
6
+ from alibi.explainers import IntegratedGradients
7
+ from alibi.datasets import load_cats
8
+ from alibi.utils.visualization import visualize_image_attr
9
+ import numpy as np
10
+ from PIL import Image
11
+ import io
12
+ import time
13
+ import os
14
+ import copy
15
+ import pickle
16
+ import datetime
17
+ import urllib.request
18
+ import gradio as gr
19
+
20
+
21
+ url = "https://upload.wikimedia.org/wikipedia/commons/3/38/Adorable-animal-cat-20787.jpg"
22
+ path_input = "/content/cat.jpg"
23
+ urllib.request.urlretrieve(url, filename=path_input)
24
+
25
+ url = "https://upload.wikimedia.org/wikipedia/commons/4/43/Cute_dog.jpg"
26
+ path_input = "/content/dog.jpg"
27
+ urllib.request.urlretrieve(url, filename=path_input)
28
+
29
+ model = ResNet50V2(weights='imagenet')
30
+
31
+ n_steps = 50
32
+ method = "gausslegendre"
33
+ internal_batch_size = 50
34
+ ig = IntegratedGradients(model,
35
+ n_steps=n_steps,
36
+ method=method,
37
+ internal_batch_size=internal_batch_size)
38
+
39
+ # refs:
40
+ # - fig2pil: https://stackoverflow.com/questions/57316491/how-to-convert-matplotlib-figure-to-pil-image-object-without-saving-image
41
+ def do_process(img, baseline):
42
+ instance = image.img_to_array(img)
43
+ instance = np.expand_dims(instance, axis=0)
44
+ instance = preprocess_input(instance)
45
+ preds = model.predict(instance)
46
+ lstPreds = decode_predictions(preds, top=3)[0]
47
+ dctPreds = {lstPreds[i][1]: round(float(lstPreds[i][2]),2) for i in range(len(lstPreds))}
48
+ predictions = preds.argmax(axis=1)
49
+ if baseline is 'white':
50
+ baselines = bls = np.ones(instance.shape).astype(instance.dtype)
51
+ elif baseline is 'black':
52
+ baselines = bls = np.zeros(instance.shape).astype(instance.dtype)
53
+ else:
54
+ baselines = np.random.random_sample(instance.shape).astype(instance.dtype)
55
+ explanation = ig.explain(instance,
56
+ baselines=baselines,
57
+ target=predictions)
58
+ attrs = explanation.attributions[0]
59
+ fig, ax = visualize_image_attr(attr=attrs.squeeze(), original_image=img, method='blended_heat_map',
60
+ sign='all', show_colorbar=True, title='Overlaid Attributions',
61
+ plt_fig_axis=None, use_pyplot=False)
62
+ buf = io.BytesIO()
63
+ fig.savefig(buf)
64
+ buf.seek(0)
65
+ img_res = Image.open(buf)
66
+ return img_res, dctPreds
67
+
68
+ input_im = gr.inputs.Image(shape=(224, 224), image_mode='RGB',
69
+ invert_colors=False, source="upload",
70
+ type="pil")
71
+ input_drop = gr.inputs.Dropdown(label='Baseline (default: random)',
72
+ choices=sorted(list(['black', 'white', 'random'])), default='random', type='value')
73
+
74
+ output_img = gr.outputs.Image(label='Output image', type='pil')
75
+ output_label = gr.outputs.Label(num_top_classes=3)
76
+
77
+ title = "XAI - Integrated gradients"
78
+ description = "Playground: Integrated gradients for a ResNet model trained on Imagenet dataset. Tools: Alibi, TF, Gradio."
79
+ examples = [['./cat.jpg'],['./dog.jpg']]
80
+ article="<p style='text-align: center'><a href='https://github.com/mawady/colab-recipes-cv' target='_blank'>Colab recipes for computer vision - Dr. Mohamed Elawady</a></p>"
81
+ iface = gr.Interface(
82
+ fn=do_process,
83
+ inputs=[input_im, input_drop],
84
+ outputs=[output_img,output_label],
85
+ live=False,
86
+ interpretation=None,
87
+ title=title,
88
+ description=description,
89
+ article=article,
90
+ examples=examples
91
+ )
92
+
93
+ iface.test_launch()
94
+
95
+ iface.launch(share=True, debug=True)
96
+