import json
import numpy as np
import tensorflow as tf
from tensorflow.keras.applications.resnet50 import ResNet50, preprocess_input
import shap
import gradio as gr

# use a non-interactive matplotlib backend so shap can render headlessly
import matplotlib
matplotlib.use('Agg')

model = ResNet50(weights='imagenet')
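
# Optional warm-up (a sketch, commented out; not required by the demo): the first
# forward pass is typically the slowest, so a dummy call here would move that
# cost from the first button click to app startup.
# model(preprocess_input(np.zeros((1, 224, 224, 3), dtype=np.float32)))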



url = "https://s3.amazonaws.com/deep-learning-models/image-models/imagenet_class_index.json"
with open(shap.datasets.cache(url)) as file:
    class_names = [v[1] for v in json.load(file).values()]
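
# Optional sanity check (an added guard, assuming the standard ImageNet head):
# predict() below zips the model's outputs positionally against class_names,
# so both must be the 1,000-entry ImageNet label set.
assert len(class_names) == 1000, f"expected 1000 ImageNet classes, got {len(class_names)}"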

def predict(x):
    """Classify an image; wired to the Classify button, returns the top 10 classes."""
    tmp = np.expand_dims(x, axis=0)
    # preprocess_input returns a new float array for uint8 input, so keep the result
    tmp = preprocess_input(tmp)
    pred = model(tmp)
    labels = list(zip(class_names, pred.numpy()[0].tolist()))
    return dict(sorted(labels, key=lambda t: t[1])[-10:])
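
# Quick smoke test (a hedged sketch, commented out; safe to delete): feed a
# random uint8 image of the shape the gr.Image component delivers and confirm
# predict() returns a label -> score dict.
# _probe = np.random.randint(0, 256, (224, 224, 3), dtype=np.uint8)
# assert isinstance(predict(_probe), dict)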

def f(x):
    """Prediction function for shap; the explainer already supplies the batch dim."""
    tmp = preprocess_input(x.copy())
    return model(tmp)

def interpretation_function(img):
    # SHAP set-up - see https://shap.readthedocs.io/en/latest/example_notebooks/image_examples/image_classification/Image%20Multi%20Class.html
    masker = shap.maskers.Image("inpaint_telea", [224, 224, 3])
    explainer = shap.PartitionExplainer(f, masker)

    # run a prediction first to know which class's shap values to keep,
    # since shap attributes every output class
    pred = f(np.expand_dims(img, 0)).numpy().argmax()
    # max_evals is kept very low so the demo stays responsive;
    # raise it for finer-grained attributions
    shap_values = explainer(np.expand_dims(img, 0), max_evals=10)

    # slice out the values for the predicted class and average over the
    # colour channels to get a (224, 224) saliency map
    scores = shap_values.values[0][:, :, :, pred]
    scores = scores.mean(axis=-1)
    # min-max normalise to [0, 1]; the epsilon guards against a constant map
    max_val, min_val = np.max(scores), np.min(scores)
    scores = (scores - min_val) / (max_val - min_val + 1e-12)
    # encode the original image to base64 so the front end can display it
    return {"original": gr.processing_utils.encode_array_to_base64(img),
            "interpretation": scores.tolist()}


with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            input_img = gr.Image(label="Input Image", shape=(224, 224))
            with gr.Row():
                classify = gr.Button("Classify")
                interpret = gr.Button("Interpret")
        with gr.Column():
            label = gr.Label(label="Predicted Class")
        with gr.Column():
            interpretation = gr.components.Interpretation(input_img)
    classify.click(predict, input_img, label)
    interpret.click(interpretation_function, input_img, interpretation)

demo.launch()