|
import json |
|
import numpy as np |
|
import tensorflow as tf |
|
from tensorflow.keras.applications.resnet50 import ResNet50, preprocess_input |
|
import shap |
|
import gradio as gr |
|
|
|
import matplotlib |
|
# Select matplotlib's non-interactive 'Agg' backend before anything plots.
matplotlib.use('Agg')

# ImageNet-pretrained ResNet50, loaded once and shared by the callbacks below.
model = ResNet50(weights='imagenet')

# Class-index JSON, fetched through shap's dataset cache. Each value is
# presumably a [wordnet_id, readable_name] pair (verify against the file);
# only element 1 of each value is kept, in the order the JSON lists them.
url = "https://s3.amazonaws.com/deep-learning-models/image-models/imagenet_class_index.json"
with open(shap.datasets.cache(url)) as file:
    class_index = json.load(file)
class_names = [entry[1] for entry in class_index.values()]
|
|
|
def predict(x):
    """Classify a single image and return its ten highest-scoring classes.

    This is the callback wired to the "Classify" button.

    Args:
        x: One image array without a batch dimension (assumed to be the
            height x width x channel array produced by the ``gr.Image``
            input -- TODO confirm dtype/shape against the Gradio version).

    Returns:
        dict mapping class name to model score for the ten classes with the
        highest scores, ordered from lowest to highest score.
    """
    # np.expand_dims returns a view, and preprocess_input may write in place
    # for float inputs, so copy first to leave the caller's array untouched.
    batch = np.expand_dims(x, axis=0).copy()
    # The return value must be used: for non-float input (e.g. uint8 images)
    # preprocess_input casts to a new array and the in-place result would
    # otherwise be silently discarded, feeding raw pixels to the model.
    batch = preprocess_input(batch)
    pred = model(batch)
    labels = list(zip(class_names, pred.numpy()[0].tolist()))
    # Ascending sort, then keep the last ten entries (the ten largest scores).
    return dict(sorted(labels, key=lambda t: t[1])[-10:])
|
|
|
def f(x):
    """Run the model on an already-batched image array.

    This is the function handed to ``shap.PartitionExplainer``; unlike
    ``predict`` it does not add a batch dimension, so ``x`` must already
    carry one.

    Args:
        x: Batched image array (leading batch axis). The input is copied
            and never modified.

    Returns:
        The raw output of ``model`` for the preprocessed batch.
    """
    # Preprocess a copy and use the *returned* array: for non-float input
    # preprocess_input casts to a new array, so relying on in-place mutation
    # would pass unpreprocessed pixels to the model.
    batch = preprocess_input(x.copy())
    return model(batch)
|
|
|
def interpretation_function(img):
    """Compute a normalised SHAP attribution map for the predicted class.

    This is the callback wired to the "Interpret" button.

    Args:
        img: One image array without a batch dimension. The masker is built
            for shape ``[224, 224, 3]``, so the image is assumed to match
            that shape -- TODO confirm against the ``gr.Image`` settings.

    Returns:
        dict with two keys:
            ``"original"``: the input image encoded as base64.
            ``"interpretation"``: nested list of per-pixel scores scaled to
            [0, 1]; all zeros when every score is identical.
    """
    masker = shap.maskers.Image("inpaint_telea", [224, 224, 3])
    explainer = shap.PartitionExplainer(f, masker)

    batch = np.expand_dims(img, 0)

    # Index of the highest-scoring class; attributions are shown for it only.
    pred = f(batch).numpy().argmax()
    shap_values = explainer(batch, max_evals=10)

    # First (only) batch element, restricted to the predicted class, then
    # averaged over the last remaining axis (presumably the colour channels).
    scores = shap_values.values[0][:, :, :, pred]
    scores = scores.mean(axis=-1)

    # Min-max normalise. With so few evaluations the map can be constant, in
    # which case the range is zero and dividing would fill the result with
    # NaNs; return an all-zero map instead.
    max_val, min_val = np.max(scores), np.min(scores)
    span = max_val - min_val
    if span > 0:
        scores = (scores - min_val) / span
    else:
        scores = np.zeros_like(scores)

    return {"original": gr.processing_utils.encode_array_to_base64(img),
            "interpretation": scores.tolist()}
|
|
|
|
|
# UI layout: one row holding three columns -- the image input with its two
# action buttons, the predicted-label output, and the interpretation output.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            input_img = gr.Image(label="Input Image", shape=(224, 224))
            with gr.Row():
                classify = gr.Button("Classify")
                interpret = gr.Button("Interpret")
        with gr.Column():
            label = gr.Label(label="Predicted Class")
        with gr.Column():
            interpretation = gr.components.Interpretation(input_img)

    # Wire each button to its callback; both read the same input image.
    classify.click(fn=predict, inputs=input_img, outputs=label)
    interpret.click(
        fn=interpretation_function,
        inputs=input_img,
        outputs=interpretation,
    )

demo.launch()
|
|