import json

import numpy as np
import matplotlib

# Select the non-interactive backend before shap (which may pull in
# matplotlib.pyplot) is imported; the demo runs on a headless server.
matplotlib.use('Agg')

import tensorflow as tf
from tensorflow.keras.applications.resnet50 import ResNet50, preprocess_input
import shap
import gradio as gr

# ResNet50 with the ImageNet classification head takes 224x224 RGB images.
INPUT_SHAPE = (224, 224, 3)
# Number of classes shown in the "Predicted Class" label.
TOP_K = 10

model = ResNet50(weights='imagenet')

url = "https://s3.amazonaws.com/deep-learning-models/image-models/imagenet_class_index.json"
with open(shap.datasets.cache(url), encoding="utf-8") as file:
    class_index = json.load(file)
# The file maps stringified class indices to [wordnet_id, label] pairs. Index
# by the integer key so list position always matches the model output unit,
# independent of key order in the JSON file.
class_names = [class_index[str(i)][1] for i in range(len(class_index))]


def predict(x):
    """Return the top-TOP_K ImageNet classes for a single image.

    ``x`` is one image without a batch dimension (as delivered by the Gradio
    ``Image`` input). The result maps class label -> model score, the format
    ``gr.Label`` consumes.
    """
    # Reuse the batched prediction path so preprocessing is identical to the
    # one used for the SHAP explanation.
    scores = f(np.expand_dims(x, axis=0)).numpy()[0]
    top = np.argsort(scores)[-TOP_K:]
    return {class_names[i]: float(scores[i]) for i in top}


def f(x):
    """Batched prediction function used by the SHAP explainer and ``predict``.

    ``x`` must carry a leading batch dimension. It is copied first so the
    caller's array is never modified by preprocessing.
    """
    # preprocess_input *returns* the preprocessed array: for non-float input
    # (e.g. uint8 images) it works on a converted copy, so ignoring the return
    # value would feed raw pixels to the model.
    return model(preprocess_input(x.copy()))


def interpretation_function(img):
    """Explain the model's top predicted class for ``img`` with SHAP.

    Returns a dict with the base64-encoded original image and a per-pixel
    importance map scaled to [0, 1], for ``gr.components.Interpretation``.
    """
    # shap set up - see https://shap.readthedocs.io/en/latest/example_notebooks/image_examples/image_classification/Image%20Multi%20Class.html
    masker = shap.maskers.Image("inpaint_telea", list(INPUT_SHAPE))
    explainer = shap.PartitionExplainer(f, masker)

    batch = np.expand_dims(img, 0)
    # shap explains every output class, so find the predicted class to know
    # which slice of the SHAP values to keep.
    pred = f(batch).numpy().argmax()
    # NOTE(review): max_evals=10 keeps the demo fast but yields a very coarse map.
    shap_values = explainer(batch, max_evals=10)

    # values[0] is indexed (row, col, channel, class): keep the predicted
    # class and average over channels to get one score per pixel.
    scores = shap_values.values[0][:, :, :, pred].mean(axis=-1)

    # Min-max scale to [0, 1]. A constant map would otherwise divide by zero
    # and produce NaN scores.
    low, high = np.min(scores), np.max(scores)
    if high > low:
        scores = (scores - low) / (high - low)
    else:
        scores = np.zeros_like(scores)

    # encode the image in a binary-safe format so the front-end can display it
    return {"original": gr.processing_utils.encode_array_to_base64(img),
            "interpretation": scores.tolist()}


with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            input_img = gr.Image(label="Input Image", shape=INPUT_SHAPE[:2])
            with gr.Row():
                classify = gr.Button("Classify")
                interpret = gr.Button("Interpret")
        with gr.Column():
            label = gr.Label(label="Predicted Class")
        with gr.Column():
            interpretation = gr.components.Interpretation(input_img)
    classify.click(predict, input_img, label)
    interpret.click(interpretation_function, input_img, interpretation)

demo.launch()