import io

import gradio as gr
import numpy as np
import PIL
import PIL.Image
import torch
import torch.nn as nn
import torch.nn.functional as F
from captum.attr import GradientShap
from captum.attr import visualization as viz
from matplotlib.colors import LinearSegmentedColormap
from torchvision import transforms
from torchvision.models import resnet50


class Explainer:
    """Classify one image with ``model`` and explain the top-1 prediction.

    The constructor preprocesses the image and runs the forward pass eagerly;
    :meth:`shap` then renders a GradientShap attribution map for the
    predicted class.
    """

    def __init__(self, model, img, class_names):
        """Preprocess ``img``, run ``model`` on it and cache the prediction.

        Args:
            model: classifier taking a ``(1, 3, 224, 224)`` tensor and
                returning logits with one column per entry of ``class_names``.
            img: input PIL image; converted to RGB before preprocessing.
            class_names: label for each output column, in column order.
        """
        self.model = model
        self.class_names = class_names
        # Colormap used by shap() for the heat map. It must be set here: a
        # second ``__init__`` definition would silently replace any other one.
        self.default_cmap = LinearSegmentedColormap.from_list(
            'custom blue',
            [(0, '#ffffff'), (0.25, '#000000'), (1, '#000000')],
            N=256,
        )

        transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
        ])
        transform_normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        )
        # Grayscale or RGBA inputs would yield 1 or 4 channels, which neither
        # the 3-channel Normalize nor the ResNet stem accepts.
        self.transformed_img = transform(img.convert('RGB'))
        self.input = transform_normalize(self.transformed_img).unsqueeze(0)

        with torch.no_grad():
            self.output = F.softmax(self.model(self.input), dim=1)

        # One confidence per class name rather than a hard-coded count.
        self.confidences = {
            name: float(self.output[0, i]) for i, name in enumerate(class_names)
        }
        self.pred_score, self.pred_label_idx = torch.topk(self.output, 1)
        # topk returns a (1, 1) tensor; convert to a plain int for list indexing.
        self.pred_label = self.class_names[self.pred_label_idx.item()]
        self.fig_title = (
            f"Predicted: {self.pred_label} "
            f"({round(self.pred_score.squeeze().item(), 2)})"
        )

    def convert_fig_to_pil(self, fig):
        """Render a Matplotlib figure to an RGB PIL image.

        ``savefig`` performs a full draw, so elements added after the figure
        was built (such as the suptitle) are included. It works with any
        canvas type and does not depend on ``FigureCanvasAgg.tostring_rgb``,
        which newer Matplotlib releases deprecate and remove.
        """
        buf = io.BytesIO()
        fig.savefig(buf, format='png')
        buf.seek(0)
        # convert() forces PIL to load the lazily-opened PNG and drops alpha.
        return PIL.Image.open(buf).convert('RGB')

    def shap(self):
        """Return a PIL image of the input beside its GradientShap heat map.

        Attributions target the top-1 predicted class.
        """
        gradient_shap = GradientShap(self.model)
        # Baseline distribution: an all-zero tensor (in normalized space) and
        # the input itself.
        baselines = torch.cat([self.input * 0, self.input * 1])
        attributions_gs = gradient_shap.attribute(
            self.input,
            n_samples=50,
            stdevs=0.0001,
            baselines=baselines,
            target=self.pred_label_idx,
        )
        # CHW -> HWC for Matplotlib.
        attr_hwc = np.transpose(
            attributions_gs.squeeze().cpu().detach().numpy(), (1, 2, 0)
        )
        img_hwc = np.transpose(
            self.transformed_img.squeeze().cpu().detach().numpy(), (1, 2, 0)
        )
        # use_pyplot=False builds a standalone Figure. Per-request figures are
        # then not registered with (and leaked by) pyplot, and no plt.show()
        # runs inside the web server.
        fig, _ = viz.visualize_image_attr_multiple(
            attr_hwc,
            img_hwc,
            ["original_image", "heat_map"],
            ["all", "absolute_value"],
            cmap=self.default_cmap,
            show_colorbar=True,
            use_pyplot=False,
        )
        fig.suptitle(self.fig_title, fontsize=12)
        return self.convert_fig_to_pil(fig)


def create_model_from_checkpoint(path="best_model", num_classes=3):
    """Build a ResNet-50 with a ``num_classes``-way head and load weights from ``path``.

    The returned model is on the CPU and in eval mode.
    """
    model = resnet50()
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    # torch.load unpickles the file: only point this at a trusted checkpoint.
    state_dict = torch.load(path, map_location=torch.device('cpu'))
    model.load_state_dict(state_dict)
    model.eval()
    return model


labels = ["benign", "malignant", "normal"]
# Keep the classifier head in sync with the label list.
model = create_model_from_checkpoint(num_classes=len(labels))


def predict(img):
    """Gradio handler: return ``[class confidences, GradientShap image]`` for ``img``."""
    if img is None:
        # Submitting without an image hands the handler None.
        raise gr.Error("Please upload an image first.")
    explainer = Explainer(model, img, labels)
    shap_img = explainer.shap()
    return [explainer.confidences, shap_img]


ui = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=[gr.Label(num_top_classes=len(labels)), gr.Image(type="pil")],
    examples=[
        "benign (52).png",
        "benign (243).png",
        "malignant (127).png",
        "malignant (201).png",
        "normal (81).png",
        "normal (101).png",
    ],
)
# Launch exactly once: launch() blocks the main thread when run as a script.
# Pass share=True for a temporary public link.
ui.launch()