Federico Galatolo commited on
Commit
fa81659
1 Parent(s): 35a06f8

gradcam working on cv image

Browse files
app.py CHANGED
@@ -197,7 +197,7 @@ def explain(img, model):
197
 
198
  state.write(f"Populating Gradcam++ for lesion #{i}...")
199
  st.subheader("Gradcam++")
200
- fig = plot_gradcam(model=MODEL, file=FILE, instance=i, fig_h=1600, fig_w=1200, fig_dpi=200, th=TH, layer="backbone.bottom_up.res5.2.conv3")
201
  st.pyplot(fig)
202
 
203
  state.write("All done...")
@@ -207,7 +207,7 @@ MODEL = "./models/model.pth"
207
  PCA_MODEL = "./models/pca.pkl"
208
  FEATURES_DATABASE = "./assets/features/features.json"
209
 
210
- st.header("Explainable oral lesion detection")
211
  st.markdown("""Demo for the paper [Explainable diagnosis of oral cancer via deep learning and case-based reasoning](https://mlpi.ing.unipi.it/doctoralai/)
212
 
213
  Upload an image using the form below and click on "Process"
 
197
 
198
  state.write(f"Populating Gradcam++ for lesion #{i}...")
199
  st.subheader("Gradcam++")
200
+ fig = plot_gradcam(model=MODEL, img=img, instance=i, fig_h=1600, fig_w=1200, fig_dpi=200, th=TH, layer="backbone.bottom_up.res5.2.conv3")
201
  st.pyplot(fig)
202
 
203
  state.write("All done...")
 
207
  PCA_MODEL = "./models/pca.pkl"
208
  FEATURES_DATABASE = "./assets/features/features.json"
209
 
210
+ st.header("Explainable Oral Lesion Detection")
211
  st.markdown("""Demo for the paper [Explainable diagnosis of oral cancer via deep learning and case-based reasoning](https://mlpi.ing.unipi.it/doctoralai/)
212
 
213
  Upload an image using the form below and click on "Process"
plots/gradcam/detectron2_gradcam.py CHANGED
@@ -86,7 +86,7 @@ class Detectron2GradCAM():
86
  checkpointer = DetectionCheckpointer(model)
87
  checkpointer.load(self.cfg.MODEL.WEIGHTS)
88
 
89
- image = read_image(img, format="BGR")
90
  input_image_dict = self._get_input_dict(image)
91
 
92
  if grad_cam_type == "GradCAM":
 
86
  checkpointer = DetectionCheckpointer(model)
87
  checkpointer.load(self.cfg.MODEL.WEIGHTS)
88
 
89
+ image = img
90
  input_image_dict = self._get_input_dict(image)
91
 
92
  if grad_cam_type == "GradCAM":
plots/plot_gradcam.py CHANGED
@@ -31,7 +31,7 @@ def plot_gradcam(**kwargs):
31
 
32
 
33
  cam_extractor = Detectron2GradCAM(config_file, cfg_list)
34
- image_dict, cam_orig = cam_extractor.get_cam(img=kwargs.file, target_instance=kwargs.instance, layer_name=kwargs.layer, grad_cam_type="GradCAM++")
35
 
36
  with torch.no_grad():
37
  fig = plt.figure(figsize=(kwargs.fig_h/kwargs.fig_dpi, kwargs.fig_w/kwargs.fig_dpi), dpi=kwargs.fig_dpi)
 
31
 
32
 
33
  cam_extractor = Detectron2GradCAM(config_file, cfg_list)
34
+ image_dict, cam_orig = cam_extractor.get_cam(img=kwargs.img, target_instance=kwargs.instance, layer_name=kwargs.layer, grad_cam_type="GradCAM++")
35
 
36
  with torch.no_grad():
37
  fig = plt.figure(figsize=(kwargs.fig_h/kwargs.fig_dpi, kwargs.fig_w/kwargs.fig_dpi), dpi=kwargs.fig_dpi)