Federico Galatolo commited on
Commit
c6268ab
1 Parent(s): 078b492

force use CPU

Browse files
Files changed (1) hide show
  1. app.py +19 -5
app.py CHANGED
@@ -87,6 +87,7 @@ def load_model():
87
  cfg.MODEL.ROI_HEADS.NUM_CLASSES = 3
88
  cfg.MODEL.WEIGHTS = MODEL
89
  cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = TH
 
90
 
91
 
92
  metadata = Metadata()
@@ -138,7 +139,10 @@ def draw_box(file_name, box, type, model, resize_input=False):
138
 
139
 
140
  def explain(img, model):
 
141
  database = json.load(open(FEATURES_DATABASE))
 
 
142
  instances, input = forward_model_full(model["model"], model["cfg"], img)
143
 
144
  instances.remove("pred_masks")
@@ -150,23 +154,28 @@ def explain(img, model):
150
  pred = cv2.resize(pred, (800, 800))
151
  pred = cv2.cvtColor(pred, cv2.COLOR_BGR2RGB)
152
 
153
- tabs = st.tabs(["Detection"] + [f"Lesion #{i}" for i in range(0, len(instances))])
154
- lesion_tabs = tabs[1:]
155
 
156
  with tabs[0]:
 
 
 
 
 
157
  st.header("Detected lesions")
158
- state.text("All done...")
159
- tooltip.success("Use the tabs for a detailed explanation of each lesion")
160
  st.image(pred)
161
 
162
 
163
  for i, (tab, box, type, scores, features) in enumerate(zip(lesion_tabs, instances.pred_boxes, instances.pred_classes, instances.probs, instances.features)):
 
164
  healthy_prob = scores[-1].item()
165
  scores = scores[:-1]
166
  features = features.tolist()
167
 
168
  with tab:
169
  st.header(f"Lesion #{i}")
 
170
  lesion_img = draw_box(img, box.cpu(), type, model)
171
  lesion_img = cv2.cvtColor(lesion_img, cv2.COLOR_BGR2RGB)
172
 
@@ -190,16 +199,21 @@ def explain(img, model):
190
  st.subheader("Feature space")
191
  col1, col2 = st.columns(2)
192
 
 
193
  fig = plot_pca_point(point=features, features_database=FEATURES_DATABASE, pca_model=PCA_MODEL, fig_h=800, fig_w=600, fig_dpi=100)
194
  col1.pyplot(fig)
195
-
 
196
  fig = plot_histogram_dist(point=features, features_database=FEATURES_DATABASE, fig_h=800, fig_w=600, fig_dpi=100)
197
  col2.pyplot(fig)
198
 
 
199
  st.subheader("Gradcam++")
200
  fig = plot_gradcam(model=MODEL, file=FILE, instance=i, fig_h=1600, fig_w=1200, fig_dpi=200, th=TH, layer="backbone.bottom_up.res5.2.conv3")
201
  st.pyplot(fig)
202
 
 
 
203
  FILE = "./test.jpg"
204
  MODEL = "./models/model.pth"
205
  PCA_MODEL = "./models/pca.pkl"
 
87
  cfg.MODEL.ROI_HEADS.NUM_CLASSES = 3
88
  cfg.MODEL.WEIGHTS = MODEL
89
  cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = TH
90
+ cfg.MODEL.DEVICE = "cpu"
91
 
92
 
93
  metadata = Metadata()
 
139
 
140
 
141
  def explain(img, model):
142
+ state.write("Loading features...")
143
  database = json.load(open(FEATURES_DATABASE))
144
+
145
+ state.write("Computing logits...")
146
  instances, input = forward_model_full(model["model"], model["cfg"], img)
147
 
148
  instances.remove("pred_masks")
 
154
  pred = cv2.resize(pred, (800, 800))
155
  pred = cv2.cvtColor(pred, cv2.COLOR_BGR2RGB)
156
 
157
+ tabs = st.tabs(["Result", "Detection"] + [f"Lesion #{i}" for i in range(0, len(instances))])
158
+ lesion_tabs = tabs[2:]
159
 
160
  with tabs[0]:
161
+ st.header("Image processed")
162
+ st.success("Use the tabs on the right to see the detected lesions and detailed explanations for each lesion")
163
+
164
+ state.write("Populating first tab...")
165
+ with tabs[1]:
166
  st.header("Detected lesions")
 
 
167
  st.image(pred)
168
 
169
 
170
  for i, (tab, box, type, scores, features) in enumerate(zip(lesion_tabs, instances.pred_boxes, instances.pred_classes, instances.probs, instances.features)):
171
+ state.write(f"Populating tab for lesion #{i}...")
172
  healthy_prob = scores[-1].item()
173
  scores = scores[:-1]
174
  features = features.tolist()
175
 
176
  with tab:
177
  st.header(f"Lesion #{i}")
178
+ state.write(f"Populating classes for lesion #{i}...")
179
  lesion_img = draw_box(img, box.cpu(), type, model)
180
  lesion_img = cv2.cvtColor(lesion_img, cv2.COLOR_BGR2RGB)
181
 
 
199
  st.subheader("Feature space")
200
  col1, col2 = st.columns(2)
201
 
202
+ state.write(f"Populating PCA for lesion #{i}...")
203
  fig = plot_pca_point(point=features, features_database=FEATURES_DATABASE, pca_model=PCA_MODEL, fig_h=800, fig_w=600, fig_dpi=100)
204
  col1.pyplot(fig)
205
+
206
+ state.write(f"Populating histogram for lesion #{i}...")
207
  fig = plot_histogram_dist(point=features, features_database=FEATURES_DATABASE, fig_h=800, fig_w=600, fig_dpi=100)
208
  col2.pyplot(fig)
209
 
210
+ state.write(f"Populating Gradcam++ for lesion #{i}...")
211
  st.subheader("Gradcam++")
212
  fig = plot_gradcam(model=MODEL, file=FILE, instance=i, fig_h=1600, fig_w=1200, fig_dpi=200, th=TH, layer="backbone.bottom_up.res5.2.conv3")
213
  st.pyplot(fig)
214
 
215
+ state.write("All done...")
216
+
217
  FILE = "./test.jpg"
218
  MODEL = "./models/model.pth"
219
  PCA_MODEL = "./models/pca.pkl"