Tobias Czempiel
commited on
Commit
β’
9128cc2
1
Parent(s):
0994b69
workflow stuff
Browse files
.DS_Store
CHANGED
Binary files a/.DS_Store and b/.DS_Store differ
|
|
checkpoints/{ham10k_checkpoint_mobile_0.82_epoch24.pt β state_dict_timm_effnet_b3_e6_val_f1=0.75.pt}
RENAMED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
-
size
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:7e3ba0141dacacd6a79a6073f5e6bbb6daa3f132944635b54adb7c94f2761e99
|
3 |
+
size 43381701
|
images/video01_000014_prep.png
ADDED
images/video01_001403.png
ADDED
images/video01_001528_pack.png
ADDED
runSDSdemo.py
CHANGED
@@ -8,18 +8,36 @@ import torchvision.transforms as transforms
|
|
8 |
from pytorch_grad_cam import GradCAM, ScoreCAM, GradCAMPlusPlus, AblationCAM, XGradCAM, EigenCAM, FullGrad
|
9 |
from pytorch_grad_cam.utils.image import show_cam_on_image
|
10 |
import gradio as gr
|
|
|
11 |
|
12 |
# model setup
|
13 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
14 |
-
classes = [
|
15 |
-
|
16 |
-
|
17 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
18 |
#state_dict_trained = torch.hub.load_state_dict_from_url("https://github.com/tobiascz/demotime/raw/main/checkpoints/ham10k_checkpoint_mobile_0.82_epoch24.pt", model_dir=".", map_location = device)
|
19 |
import os
|
20 |
print(os.getcwd())
|
21 |
-
state_dict_trained = torch.load('checkpoints/
|
22 |
-
model.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
23 |
model.eval() # This
|
24 |
|
25 |
# image pre-processing
|
@@ -76,21 +94,18 @@ def predict(input_img):
|
|
76 |
# class with hightest probability
|
77 |
pred = torch.argmax(outputs, dim=1).cpu().numpy()
|
78 |
# diagnostic suggestions
|
79 |
-
|
80 |
-
suggestion = "CHECK WITH YOUR MD!"
|
81 |
-
else:
|
82 |
-
suggestion = "Nothing to be worried about."
|
83 |
# grad_cam image
|
84 |
target_layers = model.features[-1]
|
85 |
output_img = image_grad_cam(model,leasion_tensor,input_float_np,target_layers)
|
86 |
# return label dict and suggestion
|
87 |
-
return {classes[i]: float(pred_softmax[i]) for i in range(len(classes))},
|
88 |
|
89 |
# start gradio application
|
90 |
gr.Interface(
|
91 |
fn=predict,
|
92 |
inputs=gr.inputs.Image(),
|
93 |
-
outputs=[gr.outputs.Label(label="Predict Result"), gr.outputs.
|
94 |
-
examples=[['
|
95 |
-
title="
|
96 |
).launch()
|
|
|
8 |
from pytorch_grad_cam import GradCAM, ScoreCAM, GradCAMPlusPlus, AblationCAM, XGradCAM, EigenCAM, FullGrad
|
9 |
from pytorch_grad_cam.utils.image import show_cam_on_image
|
10 |
import gradio as gr
|
11 |
+
import timm
|
12 |
|
13 |
# model setup
|
14 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
15 |
+
classes = ["Preparation",
|
16 |
+
"CalotTriangleDissection",
|
17 |
+
"ClippingCutting",
|
18 |
+
"GallbladderDissection",
|
19 |
+
"GallbladderPackaging",
|
20 |
+
"CleaningCoagulation",
|
21 |
+
"GallbladderRetraction"]
|
22 |
+
|
23 |
+
|
24 |
+
model = timm.create_model('efficientnet_b3')# This is a very well known network but it is designed for 1000 classes and not just cats and dogs this is why we need the next line
|
25 |
+
model.classifier = nn.Linear(1536, 7)
|
26 |
#state_dict_trained = torch.hub.load_state_dict_from_url("https://github.com/tobiascz/demotime/raw/main/checkpoints/ham10k_checkpoint_mobile_0.82_epoch24.pt", model_dir=".", map_location = device)
|
27 |
import os
|
28 |
print(os.getcwd())
|
29 |
+
state_dict_trained = torch.load('checkpoints/state_dict_timm_effnet_b3_e6_val_f1=0.75.pt', map_location=torch.device(device))
|
30 |
+
sd = model.state_dict()
|
31 |
+
|
32 |
+
print(state_dict_trained.keys())
|
33 |
+
|
34 |
+
for k,v in sd.items():
|
35 |
+
if not "classifier" in k:
|
36 |
+
sd[k] = state_dict_trained[f'model.model.{k}']
|
37 |
+
sd['classifier.weight'] = state_dict_trained['model.fc_phase.weight']
|
38 |
+
sd['classifier.bias'] = state_dict_trained['model.fc_phase.bias']
|
39 |
+
|
40 |
+
model.load_state_dict(sd) ## Here we load the trained weights (state_dict) in our model
|
41 |
model.eval() # This
|
42 |
|
43 |
# image pre-processing
|
|
|
94 |
# class with hightest probability
|
95 |
pred = torch.argmax(outputs, dim=1).cpu().numpy()
|
96 |
# diagnostic suggestions
|
97 |
+
|
|
|
|
|
|
|
98 |
# grad_cam image
|
99 |
target_layers = model.features[-1]
|
100 |
output_img = image_grad_cam(model,leasion_tensor,input_float_np,target_layers)
|
101 |
# return label dict and suggestion
|
102 |
+
return {classes[i]: float(pred_softmax[i]) for i in range(len(classes))}, output_img
|
103 |
|
104 |
# start gradio application
|
105 |
gr.Interface(
|
106 |
fn=predict,
|
107 |
inputs=gr.inputs.Image(),
|
108 |
+
outputs=[gr.outputs.Label(label="Predict Result"), gr.outputs.Image(label="GradCAM")],
|
109 |
+
examples=[['images/video01_000014_prep.png'],['images/video01_001403.png'],['images/video01_001528_pack.png']],
|
110 |
+
title="Surgical Workflow Classifier"
|
111 |
).launch()
|