PedroMartelleto commited on
Commit
1a23377
1 Parent(s): 2d26d4d

Explainability with SHAP

Browse files
Files changed (2) hide show
  1. app.py +6 -25
  2. requirements.txt +3 -0
app.py CHANGED
@@ -1,8 +1,9 @@
1
  import gradio as gr
2
  from torchvision.models import resnet50, ResNet50_Weights
3
- from torchvision import transforms
4
  import torch.nn as nn
5
  import torch
 
 
6
 
7
  @staticmethod
8
  def create_model_from_checkpoint():
@@ -13,36 +14,16 @@ def create_model_from_checkpoint():
13
  model.eval()
14
  return model
15
 
16
- def prep_image(img):
17
- transform = transforms.Compose([
18
- transforms.Resize(256),
19
- transforms.CenterCrop(224),
20
- transforms.ToTensor()
21
- ])
22
-
23
- transform_normalize = transforms.Normalize(
24
- mean=[0.485, 0.456, 0.406],
25
- std=[0.229, 0.224, 0.225]
26
- )
27
-
28
- transformed_img = transform(img)
29
-
30
- input = transform_normalize(transformed_img)
31
- input = input.unsqueeze(0)
32
- return input
33
-
34
  model = create_model_from_checkpoint()
35
  labels = [ "benign", "malignant", "normal" ]
36
 
37
  def predict(img):
38
- input = prep_image(img)
39
- with torch.no_grad():
40
- prediction = torch.nn.functional.softmax(model(input)[0], dim=0)
41
- confidences = {labels[i]: float(prediction[i]) for i in range(3)}
42
- return confidences
43
 
44
  ui = gr.Interface(fn=predict,
45
  inputs=gr.Image(type="pil"),
46
- outputs=gr.Label(num_top_classes=3),
47
  examples=["benign (52).png", "benign (243).png", "malignant (127).png", "malignant (201).png", "normal (81).png", "normal (101).png"]).launch()
48
  ui.launch(share=True)
 
1
  import gradio as gr
2
  from torchvision.models import resnet50, ResNet50_Weights
 
3
  import torch.nn as nn
4
  import torch
5
+ import numpy as np
6
+ from explain import Explainer
7
 
8
  @staticmethod
9
  def create_model_from_checkpoint():
 
14
  model.eval()
15
  return model
16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  model = create_model_from_checkpoint()
18
  labels = [ "benign", "malignant", "normal" ]
19
 
20
  def predict(img):
21
+ explainer = Explainer(model, img, labels)
22
+ shap_img = explainer.shap()
23
+ return [explainer.confidences, shap_img]
 
 
24
 
25
  ui = gr.Interface(fn=predict,
26
  inputs=gr.Image(type="pil"),
27
+ outputs=[gr.Label(num_top_classes=3), gr.Image(type="pil")],
28
  examples=["benign (52).png", "benign (243).png", "malignant (127).png", "malignant (201).png", "normal (81).png", "normal (101).png"]).launch()
29
  ui.launch(share=True)
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ torch
2
+ torchvision
3
+ captum