joheras committed on
Commit
979d3e3
·
1 Parent(s): 7f7f14d

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +400 -0
app.py ADDED
@@ -0,0 +1,400 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ import matplotlib.pyplot as plt
4
+ from scipy import ndimage
5
+ from IPython.display import Image
6
+
7
+ import tensorflow as tf
8
+ from tensorflow import keras
9
+ from tensorflow.keras import layers
10
+ from tensorflow.keras.applications import xception
11
+
12
# Shape used for the baseline images built by the integrated-gradients
# helpers below; its first two entries match the 299x299 size requested from
# the Gradio image input.
img_size = (299, 299, 3)

# Load Xception model with imagenet weights. This runs once at import time and
# the resulting model is shared by every request.
model = xception.Xception(weights="imagenet")

# The local path to our target image.
# NOTE(review): img_path is not referenced anywhere else in the visible file;
# the Gradio example below points at "elephant.jpg.jpg" instead -- confirm
# whether this download is still needed.
img_path = keras.utils.get_file("elephant.jpg", "https://i.imgur.com/Bvro0YD.png")
20
+
21
def get_gradients(img_input, top_pred_idx):
    """Differentiate one class score of the global model w.r.t. the input.

    Args:
        img_input: 4D image tensor
        top_pred_idx: Index of the class whose score is differentiated

    Returns:
        Gradients of the selected class score w.r.t. img_input (after the
        input has been cast to float32)
    """
    watched = tf.cast(img_input, tf.float32)

    with tf.GradientTape() as tape:
        # `watched` is a plain tensor rather than a variable, so the tape has
        # to be told explicitly to track it.
        tape.watch(watched)
        class_score = model(watched)[:, top_pred_idx]

    return tape.gradient(class_score, watched)
40
+
41
+
42
def get_integrated_gradients(img_input, top_pred_idx, baseline=None, num_steps=50):
    """Approximate Integrated Gradients for one class along a straight path.

    Args:
        img_input (ndarray): Original image
        top_pred_idx: Predicted label for the input image
        baseline (ndarray): Image at which the interpolation path starts.
            When omitted, an all-zero (black) image of shape `img_size`
            is used.
        num_steps: Number of interpolation intervals between the baseline
            and the input. This value alone determines the error of the
            integral approximation. Defaults to 50.

    Returns:
        Integrated gradients w.r.t input image
    """
    # Fall back to a black image when the caller gave no starting point.
    start = np.zeros(img_size) if baseline is None else baseline
    baseline = start.astype(np.float32)

    # 1. Walk from the baseline to the input in `num_steps` equal increments
    #    (num_steps + 1 images, both endpoints included).
    img_input = img_input.astype(np.float32)
    delta = img_input - baseline
    path = np.array(
        [baseline + (step / num_steps) * delta for step in range(num_steps + 1)]
    ).astype(np.float32)

    # 2. Apply the Xception preprocessing to every image on the path.
    path = xception.preprocess_input(path)

    # 3. One gradient per path image; each image is fed as a batch of one.
    per_step = [
        get_gradients(tf.expand_dims(frame, axis=0), top_pred_idx=top_pred_idx)[0]
        for frame in path
    ]
    per_step = tf.convert_to_tensor(per_step, dtype=tf.float32)

    # 4. Trapezoidal rule: average neighbouring gradients, then average the
    #    resulting intervals.
    trapezoids = (per_step[:-1] + per_step[1:]) / 2.0
    mean_grad = tf.reduce_mean(trapezoids, axis=0)

    # 5. Scale the averaged gradient by the input-minus-baseline difference.
    return delta * mean_grad
90
+
91
+
92
def random_baseline_integrated_gradients(
    img_input, top_pred_idx, num_steps=50, num_runs=2
):
    """Average Integrated Gradients over several random baseline images.

    Args:
        img_input (ndarray): 3D image
        top_pred_idx: Predicted label for the input image
        num_steps: Number of interpolation intervals between each baseline
            and the input, forwarded to `get_integrated_gradients`.
            Defaults to 50.
        num_runs: number of random baseline images to draw

    Returns:
        Averaged integrated gradients for `num_runs` baseline images
    """
    # Each run draws a fresh uniform-noise baseline scaled to [0, 255) and
    # computes the integrated gradients from it.
    per_baseline = [
        get_integrated_gradients(
            img_input=img_input,
            top_pred_idx=top_pred_idx,
            baseline=np.random.random(img_size) * 255,
            num_steps=num_steps,
        )
        for _ in range(num_runs)
    ]

    # Mean over the run axis gives the final attribution map.
    return tf.reduce_mean(tf.convert_to_tensor(per_baseline), axis=0)
126
+
127
class GradVisualizer:
    """Turn attribution maps into RGB highlights for an input image.

    Attributions of the selected polarity are averaged over the channel
    axis, rescaled to [0, 1], optionally cleaned up / outlined, tinted with
    a colour and (optionally) blended on top of the original image.
    """

    def __init__(self, positive_channel=None, negative_channel=None):
        """Store the tint colours.

        Args:
            positive_channel: RGB triple used for positive attributions.
                Defaults to green, [0, 255, 0].
            negative_channel: RGB triple used for negative attributions.
                Defaults to red, [255, 0, 0].
        """
        self.positive_channel = (
            [0, 255, 0] if positive_channel is None else positive_channel
        )
        self.negative_channel = (
            [255, 0, 0] if negative_channel is None else negative_channel
        )

    def apply_polarity(self, attributions, polarity):
        """Keep one sign of the attributions.

        "positive" clips to [0, 1]; any other value clips to [-1, 0].
        """
        if polarity == "positive":
            return np.clip(attributions, 0, 1)
        return np.clip(attributions, -1, 0)

    def apply_linear_transformation(
        self,
        attributions,
        clip_above_percentile=99.9,
        clip_below_percentile=70.0,
        lower_end=0.2,
    ):
        """Linearly rescale attribution magnitudes into [lower_end, 1].

        Args:
            attributions: array of attribution values.
            clip_above_percentile: magnitudes at or above the threshold
                derived from this percentile map to 1.0.
            clip_below_percentile: magnitudes at the threshold derived from
                this percentile map to `lower_end`; smaller ones are zeroed.
            lower_end: output value assigned to the lower threshold.

        Returns:
            Array of the same shape with values in [0, 1].
        """
        # 1. Get the thresholds
        m = self.get_thresholded_attributions(
            attributions, percentage=100 - clip_above_percentile
        )
        e = self.get_thresholded_attributions(
            attributions, percentage=100 - clip_below_percentile
        )

        # 2. Transform the attributions by a linear function f(x) = a*x + b
        #    such that f(m) = 1.0 and f(e) = lower_end
        magnitude = np.abs(attributions)
        span = m - e
        if span == 0:
            # Both thresholds coincide (e.g. constant or all-zero input), so
            # the line is undefined. Dividing would yield NaN/inf; instead
            # send everything at or above the threshold straight to 1.0.
            transformed = np.where(magnitude >= m, 1.0, 0.0)
        else:
            transformed = (1 - lower_end) * (magnitude - e) / span + lower_end

        # 3. Make sure that the sign of transformed attributions is the same
        #    as original attributions (this also zeroes exact zeros)
        transformed *= np.sign(attributions)

        # 4. Only keep values that are bigger than the lower_end
        transformed *= transformed >= lower_end

        # 5. Clip values and return
        return np.clip(transformed, 0.0, 1.0)

    def get_thresholded_attributions(self, attributions, percentage):
        """Return the magnitude at which the top attributions reach a share.

        Magnitudes are sorted from largest to smallest and accumulated; the
        result is the first magnitude at which the running sum reaches
        `percentage` percent of the total magnitude.

        Args:
            attributions: array of attribution values.
            percentage: share of the total magnitude, in [0, 100].

        Returns:
            The threshold value. For `percentage == 100.0` this is the
            minimum of `attributions`; for input with zero total magnitude
            it is 0.0.
        """
        if percentage == 100.0:
            return np.min(attributions)

        # 1. Flatten and sort the magnitudes from largest to smallest
        magnitudes = np.sort(np.abs(attributions.flatten()))[::-1]

        # 2. Total magnitude. It is taken over absolute values so that it is
        #    consistent with the cumulative sum below even for signed input.
        total = np.sum(magnitudes)
        if total == 0:
            # Nothing to rank; avoid a 0/0 division and an empty lookup.
            return 0.0

        # 3. Percentage of the total contributed by each magnitude together
        #    with all the larger ones
        cum_sum = 100.0 * np.cumsum(magnitudes) / total

        # 4. First position where the requested share is reached
        first_idx = np.where(cum_sum >= percentage)[0][0]
        return magnitudes[first_idx]

    def binarize(self, attributions, threshold=0.001):
        """Return a boolean mask of the values strictly above `threshold`."""
        return attributions > threshold

    def morphological_cleanup_fn(self, attributions, structure=np.ones((4, 4))):
        """Apply a grey-scale closing followed by an opening.

        The default `structure` array is only read, never modified.
        """
        closed = ndimage.grey_closing(attributions, structure=structure)
        return ndimage.grey_opening(closed, structure=structure)

    def draw_outlines(
        self, attributions, percentage=90, connected_component_structure=np.ones((3, 3))
    ):
        """Reduce an attribution map to the borders of its largest regions.

        Args:
            attributions: 2D attribution map.
            percentage: share of the highlighted area that the kept
                components must cover; at most three components are kept.
            connected_component_structure: connectivity passed to
                `scipy.ndimage.label`.

        Returns:
            Boolean mask that is True on the border pixels of the kept
            components. It is all False when nothing exceeds the
            binarization threshold.
        """
        # 1. Binarize the attributions and fill the gaps
        filled = ndimage.binary_fill_holes(self.binarize(attributions))

        # 2. Compute connected components
        labelled, num_comp = ndimage.label(
            filled, structure=connected_component_structure
        )

        border_mask = np.zeros_like(filled)
        if num_comp == 0:
            # No highlighted pixel at all: there is nothing to outline.
            return border_mask

        # 3. Sum up the (binary) attributions for each component
        total = np.sum(filled[labelled > 0])
        component_sums = []
        for comp in range(1, num_comp + 1):
            mask = labelled == comp
            component_sums.append((np.sum(filled[mask]), mask))

        # 4. Rank components by size and find how many are needed to cover
        #    the requested share, capped at the three largest
        ranked = sorted(component_sums, key=lambda item: item[0], reverse=True)
        cumulative = np.cumsum([size for size, _ in ranked])
        cutoff_threshold = percentage * total / 100
        cutoff_idx = min(np.where(cumulative >= cutoff_threshold)[0][0], 2)

        # 5. Set the values for the kept components
        for _, mask in ranked[: cutoff_idx + 1]:
            border_mask[mask] = 1

        # 6. Make the mask hollow and show only the border
        eroded_mask = ndimage.binary_erosion(border_mask, iterations=1)
        border_mask[eroded_mask] = 0

        return border_mask

    def process_grads(
        self,
        image,
        attributions,
        polarity="positive",
        clip_above_percentile=99.9,
        clip_below_percentile=0,
        morphological_cleanup=False,
        structure=np.ones((3, 3)),
        outlines=False,
        outlines_component_percentage=90,
        overlay=True,
    ):
        """Convert one attribution map into an RGB highlight.

        Args:
            image: original image the highlight is blended onto.
            attributions: attribution map with the channel on axis 2.
            polarity: "positive" or "negative" -- which sign to show.
            clip_above_percentile: upper percentile for the rescaling.
            clip_below_percentile: lower percentile for the rescaling.
            morphological_cleanup: apply closing + opening when True.
            structure: structuring element for the cleanup.
            outlines: keep only the borders of the main regions when True.
            outlines_component_percentage: share passed to `draw_outlines`.
            overlay: when True, add 0.8 * highlight to `image` and clip the
                result to [0, 255]; otherwise return the bare highlight.

        Returns:
            Array with three colour channels on the last axis.

        Raises:
            ValueError: for an unknown polarity or a percentile outside
                [0, 100].
        """
        if polarity not in ["positive", "negative"]:
            raise ValueError(
                f""" Allowed polarity values: 'positive' or 'negative'
but provided {polarity}"""
            )
        if clip_above_percentile < 0 or clip_above_percentile > 100:
            raise ValueError("clip_above_percentile must be in [0, 100]")

        if clip_below_percentile < 0 or clip_below_percentile > 100:
            raise ValueError("clip_below_percentile must be in [0, 100]")

        # 1. Apply polarity; negative values are flipped so that the rest of
        #    the pipeline only ever sees non-negative magnitudes
        attributions = self.apply_polarity(attributions, polarity=polarity)
        if polarity == "positive":
            channel = self.positive_channel
        else:
            attributions = np.abs(attributions)
            channel = self.negative_channel

        # 2. Take average over the channels
        attributions = np.average(attributions, axis=2)

        # 3. Apply linear transformation to the attributions
        attributions = self.apply_linear_transformation(
            attributions,
            clip_above_percentile=clip_above_percentile,
            clip_below_percentile=clip_below_percentile,
            lower_end=0.0,
        )

        # 4. Cleanup
        if morphological_cleanup:
            attributions = self.morphological_cleanup_fn(
                attributions, structure=structure
            )

        # 5. Draw the outlines
        if outlines:
            attributions = self.draw_outlines(
                attributions, percentage=outlines_component_percentage
            )

        # 6. Expand the channel axis and convert to RGB
        attributions = np.expand_dims(attributions, 2) * channel

        # 7. Superimpose on the original image
        if overlay:
            attributions = np.clip((attributions * 0.8 + image), 0, 255)
        return attributions

    def visualize(
        self,
        image,
        gradients,
        integrated_gradients,
        polarity="positive",
        clip_above_percentile=99.9,
        clip_below_percentile=0,
        morphological_cleanup=False,
        structure=np.ones((3, 3)),
        outlines=False,
        outlines_component_percentage=90,
        overlay=True,
        figsize=(15, 8),
    ):
        """Render the integrated-gradients highlight as a uint8 image.

        Args:
            image: original image.
            gradients: accepted for interface compatibility only; the
                result does not depend on it, so it is no longer processed.
            integrated_gradients: attribution map to render.
            polarity, clip_above_percentile, clip_below_percentile,
            morphological_cleanup, structure, outlines,
            outlines_component_percentage, overlay: forwarded unchanged to
                `process_grads`.
            figsize: accepted for interface compatibility only; unused.

        Returns:
            The processed integrated-gradients image cast to uint8.
        """
        # Work on a copy so the caller's image is never touched.
        igrads_attr = self.process_grads(
            image=np.copy(image),
            attributions=integrated_gradients,
            polarity=polarity,
            clip_above_percentile=clip_above_percentile,
            clip_below_percentile=clip_below_percentile,
            morphological_cleanup=morphological_cleanup,
            structure=structure,
            outlines=outlines,
            outlines_component_percentage=outlines_component_percentage,
            overlay=overlay,
        )

        return igrads_attr.astype(np.uint8)
364
+
365
def classify_image(image):
    """Explain the model's top prediction for one image.

    The image is classified with the global Xception model, integrated
    gradients are computed for the top class from two random baselines, and
    the attributions are blended onto the original image.

    Args:
        image: image array as delivered by the Gradio image input
            (assumed height x width x 3 -- TODO confirm against the
            installed Gradio version).

    Returns:
        uint8 image array with the integrated-gradients highlight
        superimposed on the input, suitable for the image output component.
    """
    batch = np.expand_dims(image, axis=0)
    # Keep an untouched uint8 copy before any preprocessing happens.
    orig_img = np.copy(batch[0]).astype(np.uint8)
    img_processed = tf.cast(xception.preprocess_input(batch), dtype=tf.float32)

    preds = model.predict(img_processed)
    top_pred_idx = tf.argmax(preds[0])
    print("Predicted:", top_pred_idx, xception.decode_predictions(preds, top=1)[0])

    grads = get_gradients(img_processed, top_pred_idx=top_pred_idx)
    igrads = random_baseline_integrated_gradients(
        np.copy(orig_img), top_pred_idx=top_pred_idx, num_steps=50, num_runs=2
    )

    vis = GradVisualizer()
    img_grads = vis.visualize(
        image=orig_img,
        gradients=grads[0].numpy(),
        integrated_gradients=igrads.numpy(),
        clip_above_percentile=99,
        clip_below_percentile=0,
    )
    # The interface output is an image, so return the rendered overlay. The
    # previous return referenced the undefined names `labels` and
    # `prediction` and raised NameError on every request.
    return img_grads
384
+
385
# Input component: images are resized to the 299x299 size used by the model.
image = gr.inputs.Image(shape=(299, 299))
# Output component. Despite its name this is an image (the attribution
# overlay), not a class label; the name is kept for compatibility.
label = gr.outputs.Image()

iface = gr.Interface(
    classify_image,
    image,
    label,
    # NOTE(review): this path must exist next to the app; it differs from the
    # "elephant.jpg" fetched at the top of the file -- confirm the file name.
    examples=["elephant.jpg.jpg"],
    title="Model interpretability with Integrated Gradients",
    # The previous description was left over from another demo (CIFAR /
    # vision transformer) and did not describe this app.
    description="Highlights the pixels that most influenced the top ImageNet prediction of an Xception model, using Integrated Gradients.",
    article="Author: <a href=\"https://huggingface.co/joheras\">Jónathan Heras</a>",
)


iface.launch()