PedroMartelleto commited on
Commit
c396e65
1 Parent(s): 1b87171

Deploying to HF

Browse files
Files changed (1) hide show
  1. app.py +55 -4
app.py CHANGED
@@ -1,5 +1,5 @@
1
  import PIL
2
- from captum.attr import GradientShap
3
  from captum.attr import visualization as viz
4
  import torch
5
  from torchvision import transforms
@@ -65,6 +65,58 @@ class Explainer:
65
  fig.suptitle(self.fig_title, fontsize=12)
66
  return self.convert_fig_to_pil(fig)
67
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
  def create_model_from_checkpoint():
69
  # Loads a model from a checkpoint
70
  model = resnet50()
@@ -78,11 +130,10 @@ labels = [ "benign", "malignant", "normal" ]
78
 
79
  def predict(img):
80
  explainer = Explainer(model, img, labels)
81
- shap_img = explainer.shap()
82
- return [explainer.confidences, shap_img]
83
 
84
  ui = gr.Interface(fn=predict,
85
  inputs=gr.Image(type="pil"),
86
- outputs=[gr.Label(num_top_classes=3), gr.Image(type="pil")],
87
  examples=["benign (52).png", "benign (243).png", "malignant (127).png", "malignant (201).png", "normal (81).png", "normal (101).png"]).launch()
88
  ui.launch(share=True)
 
1
  import PIL
2
+ from captum.attr import GradientShap, Occlusion, LayerGradCam, LayerAttribution, IntegratedGradients
3
  from captum.attr import visualization as viz
4
  import torch
5
  from torchvision import transforms
 
65
  fig.suptitle(self.fig_title, fontsize=12)
66
  return self.convert_fig_to_pil(fig)
67
 
68
+ def occlusion(self):
69
+ occlusion = Occlusion(model)
70
+
71
+ attributions_occ = occlusion.attribute(self.input,
72
+ target=self.pred_label_idx,
73
+ strides=(3, 8, 8),
74
+ sliding_window_shapes=(3,15, 15),
75
+ baselines=0)
76
+
77
+ fig, _ = viz.visualize_image_attr_multiple(np.transpose(attributions_occ.squeeze().cpu().detach().numpy(), (1,2,0)),
78
+ np.transpose(self.transformed_img.squeeze().cpu().detach().numpy(), (1,2,0)),
79
+ ["original_image", "heat_map", "heat_map", "masked_image"],
80
+ ["all", "positive", "negative", "positive"],
81
+ show_colorbar=True,
82
+ titles=["Original", "Positive Attribution", "Negative Attribution", "Masked"],
83
+ fig_size=(18, 6)
84
+ )
85
+ fig.suptitle(self.fig_title, fontsize=12)
86
+ return self.convert_fig_to_pil(fig)
87
+
88
+ def gradcam(self):
89
+ layer_gradcam = LayerGradCam(self.model, self.model.layer3[1].conv2)
90
+ attributions_lgc = layer_gradcam.attribute(self.input, target=self.pred_label_idx)
91
+
92
+ #_ = viz.visualize_image_attr(attributions_lgc[0].cpu().permute(1,2,0).detach().numpy(),
93
+ # sign="all",
94
+ # title="Layer 3 Block 1 Conv 2")
95
+ upsamp_attr_lgc = LayerAttribution.interpolate(attributions_lgc, self.input.shape[2:])
96
+
97
+ fig, _ = viz.visualize_image_attr_multiple(upsamp_attr_lgc[0].cpu().permute(1,2,0).detach().numpy(),
98
+ self.transformed_img.permute(1,2,0).numpy(),
99
+ ["original_image","blended_heat_map","masked_image"],
100
+ ["all","positive","positive"],
101
+ show_colorbar=True,
102
+ titles=["Original", "Positive Attribution", "Masked"],
103
+ fig_size=(18, 6))
104
+ return self.convert_fig_to_pil(fig)
105
+
106
+ def integrated_gradients(self):
107
+ integrated_gradients = IntegratedGradients(self.model)
108
+ attributions_ig = integrated_gradients.attribute(self.input, target=self.pred_label_idx, n_steps=50)
109
+
110
+ fig, _ = viz.visualize_image_attr_multiple(np.transpose(attributions_ig.squeeze().cpu().detach().numpy(), (1,2,0)),
111
+ np.transpose(self.transformed_img.squeeze().cpu().detach().numpy(), (1,2,0)),
112
+ ["original_image", "heat_map", "masked_image"],
113
+ ["all", "positive", "positive"],
114
+ show_colorbar=True,
115
+ titles=["Original", "Attribution", "Masked"],
116
+ fig_size=(18, 6))
117
+ fig.suptitle(self.fig_title, fontsize=12)
118
+ return self.convert_fig_to_pil(fig)
119
+
120
  def create_model_from_checkpoint():
121
  # Loads a model from a checkpoint
122
  model = resnet50()
 
130
 
131
  def predict(img):
132
  explainer = Explainer(model, img, labels)
133
+ return [explainer.confidences, explainer.shap(), explainer.occlusion(), explainer.gradcam(), explainer.integrated_gradients()]
 
134
 
135
  ui = gr.Interface(fn=predict,
136
  inputs=gr.Image(type="pil"),
137
+ outputs=[gr.Label(num_top_classes=3), gr.Image(type="pil"), gr.Image(type="pil"), gr.Image(type="pil"), gr.Image(type="pil")],
138
  examples=["benign (52).png", "benign (243).png", "malignant (127).png", "malignant (201).png", "normal (81).png", "normal (101).png"]).launch()
139
  ui.launch(share=True)