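"""Gradio demo: classify breast-ultrasound images as benign / malignant / normal
with a fine-tuned ResNet-50 and explain each prediction with Captum's
GradientShap attributions."""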
import gradio as gr
import numpy as np
import PIL.Image
import torch
import torch.nn as nn
import torch.nn.functional as F
from captum.attr import GradientShap
from captum.attr import visualization as viz
from matplotlib.colors import LinearSegmentedColormap
from torchvision import transforms
from torchvision.models import resnet50

class Explainer:
    """Runs the model on one image and caches the prediction and attribution inputs."""

    def __init__(self, model, img, class_names):
        self.model = model
        self.class_names = class_names
        # White-to-black colormap used for the attribution heat map.
        self.default_cmap = LinearSegmentedColormap.from_list(
            'custom blue',
            [(0, '#ffffff'), (0.25, '#000000'), (1, '#000000')],
            N=256)

        # Standard ImageNet preprocessing: resize, center-crop, normalize.
        transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor()
        ])
        transform_normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        )
        self.transformed_img = transform(img)
        self.input = transform_normalize(self.transformed_img).unsqueeze(0)  # add batch dim

        # One forward pass; cache softmax confidences for the Gradio Label output.
        with torch.no_grad():
            self.output = F.softmax(self.model(self.input), dim=1)
        self.confidences = {class_names[i]: float(self.output[0, i])
                            for i in range(len(class_names))}
        self.pred_score, pred_label_idx = torch.topk(self.output, 1)
        self.pred_label_idx = pred_label_idx.squeeze().item()  # plain int, usable as list index
        self.pred_label = self.class_names[self.pred_label_idx]
        self.fig_title = (f'Predicted: {self.pred_label} '
                          f'({self.pred_score.squeeze().item():.2f})')

    def convert_fig_to_pil(self, fig):
        # Render the figure first; tostring_rgb() needs a drawn canvas.
        fig.canvas.draw()
        return PIL.Image.frombytes('RGB', fig.canvas.get_width_height(),
                                   fig.canvas.tostring_rgb())

    def shap(self):
        # GradientShap attributions for the predicted class, using the all-black
        # image and the input itself as the baseline distribution.
        gradient_shap = GradientShap(self.model)
        rand_img_dist = torch.cat([self.input * 0, self.input * 1])
        attributions_gs = gradient_shap.attribute(self.input,
                                                  n_samples=50,
                                                  stdevs=0.0001,
                                                  baselines=rand_img_dist,
                                                  target=self.pred_label_idx)
        # Side-by-side panels: the preprocessed image and an absolute-value heat map.
        fig, _ = viz.visualize_image_attr_multiple(
            np.transpose(attributions_gs.squeeze().cpu().detach().numpy(), (1, 2, 0)),
            np.transpose(self.transformed_img.squeeze().cpu().detach().numpy(), (1, 2, 0)),
            ["original_image", "heat_map"],
            ["all", "absolute_value"],
            cmap=self.default_cmap,
            show_colorbar=True)
        fig.suptitle(self.fig_title, fontsize=12)
        return self.convert_fig_to_pil(fig)
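
# Usage sketch outside Gradio (hypothetical; the app below drives this class
# through predict() instead):
#
#   img = PIL.Image.open("benign (52).png").convert("RGB")
#   explainer = Explainer(model, img, labels)
#   explainer.confidences   # {class name: probability} for the Label widget
#   explainer.shap()        # PIL image: original + GradientShap heat map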

def create_model_from_checkpoint():
    # Rebuild the fine-tuned classifier: a ResNet-50 backbone whose final fully
    # connected layer is replaced by a 3-class head, with weights restored from
    # the "best_model" checkpoint on CPU.
    model = resnet50()
    model.fc = nn.Linear(model.fc.in_features, 3)
    model.load_state_dict(torch.load("best_model", map_location=torch.device('cpu')))
    model.eval()  # inference mode: fixes batch-norm and dropout behavior
    return model


model = create_model_from_checkpoint()
labels = ["benign", "malignant", "normal"]
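# Note (assumption): this label order must match the class indices the 3-way
# head was trained with, and "best_model" is expected to sit next to this script.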

def predict(img):
    explainer = Explainer(model, img, labels)
    shap_img = explainer.shap()
    return [explainer.confidences, shap_img]

ui = gr.Interface(fn=predict,
                  inputs=gr.Image(type="pil"),
                  outputs=[gr.Label(num_top_classes=3), gr.Image(type="pil")],
                  examples=["benign (52).png", "benign (243).png",
                            "malignant (127).png", "malignant (201).png",
                            "normal (81).png", "normal (101).png"])
ui.launch(share=True)