# freddyaboulton — commit 37d4f54: "Fix quote"
import json
import numpy as np
import tensorflow as tf
from tensorflow.keras.applications.resnet50 import ResNet50, preprocess_input
import shap
import gradio as gr
import matplotlib
matplotlib.use('Agg')
# ResNet50 with its ImageNet classification head and pretrained ImageNet weights.
model = ResNet50(weights='imagenet')

# Human-readable ImageNet class names, ordered by the model's output index.
url = "https://s3.amazonaws.com/deep-learning-models/image-models/imagenet_class_index.json"
with open(shap.datasets.cache(url), encoding="utf-8") as file:
    # Expected schema: {"0": [wordnet_id, class_name], "1": [...], ...}.
    # Look entries up by integer position rather than trusting the key order
    # of the JSON object, so class_names[i] always matches output index i.
    _class_index = json.load(file)
class_names = [_class_index[str(i)][1] for i in range(len(_class_index))]
def predict(x):
    """Classify a single image with ResNet50 and return the top-10 classes.

    Args:
        x: one image without a batch dimension, presumably the (224, 224, 3)
            RGB array produced by the ``gr.Image`` input (TODO confirm dtype).

    Returns:
        dict mapping class name -> model score for the ten highest-scoring
        classes, inserted in ascending order of score.
    """
    # np.array copies, so preprocessing below can never mutate the caller's
    # image (np.expand_dims alone would return a view of ``x``).
    batch = np.expand_dims(np.array(x, dtype=np.float32), axis=0)
    # Use preprocess_input's *return value*: it is not guaranteed to work in
    # place (e.g. integer input is converted into a new array).
    batch = preprocess_input(batch)
    scores = model(batch).numpy()[0]
    labels = list(zip(class_names, scores.tolist()))
    return dict(sorted(labels, key=lambda t: t[1])[-10:])
def f(x):
    """Batched prediction function handed to the SHAP explainer.

    Args:
        x: a batch of images whose first axis is the batch dimension (this
            function does not add one, unlike ``predict``).

    Returns:
        The raw model output for the batch (a TensorFlow tensor; callers use
        ``.numpy()`` on it).
    """
    # np.array copies, so SHAP's masked inputs are never mutated. Keep the
    # return value of preprocess_input; it is not guaranteed to be in place.
    batch = preprocess_input(np.array(x, dtype=np.float32))
    return model(batch)
def interpretation_function(img):
    """Explain the top-1 ResNet50 prediction for ``img`` with SHAP.

    Args:
        img: a single image, expected to be 224x224x3 to match the masker
            below (the ``gr.Image`` input is configured with shape 224x224).

    Returns:
        dict with ``"original"`` (the image encoded by gradio's
        ``encode_array_to_base64``) and ``"interpretation"`` (a nested list of
        per-pixel scores scaled to [0, 1]), the payload fed to the
        ``gr.components.Interpretation`` output.
    """
    # SHAP setup, following
    # https://shap.readthedocs.io/en/latest/example_notebooks/image_examples/image_classification/Image%20Multi%20Class.html
    masker = shap.maskers.Image("inpaint_telea", [224, 224, 3])
    explainer = shap.PartitionExplainer(f, masker)

    batch = np.expand_dims(img, 0)
    # SHAP explains every output class, so find the predicted class to know
    # which slice of the SHAP values to keep.
    pred = f(batch).numpy().argmax()
    # Small evaluation budget, presumably chosen to keep latency low.
    shap_values = explainer(batch, max_evals=10)

    # The class axis is last; the remaining axes are presumably
    # (height, width, channel). Average over channels -> one score per pixel.
    scores = shap_values.values[0][:, :, :, pred].mean(axis=-1)

    # Min-max scale to [0, 1]. A constant map would otherwise divide by zero
    # and produce NaNs, which are not valid JSON for the front end.
    min_val, max_val = scores.min(), scores.max()
    value_range = max_val - min_val
    if value_range > 0:
        scores = (scores - min_val) / value_range
    else:
        scores = np.zeros_like(scores)

    # Encode the original image as base64 so the front end can display it.
    return {"original": gr.processing_utils.encode_array_to_base64(img),
            "interpretation": scores.tolist()}
# UI layout: [input image + buttons] | [predicted label] | [SHAP interpretation]
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            input_img = gr.Image(label="Input Image", shape=(224, 224))
            with gr.Row():
                classify = gr.Button("Classify")
                interpret = gr.Button("Interpret")
        with gr.Column():
            label = gr.Label(label="Predicted Class")
        with gr.Column():
            interpretation = gr.components.Interpretation(input_img)

    # Each button sends the uploaded image to its handler.
    classify.click(fn=predict, inputs=input_img, outputs=label)
    interpret.click(fn=interpretation_function, inputs=input_img,
                    outputs=interpretation)

demo.launch()