Hui committed · Commit ea774f6 · Parent(s): 6b2ab77

codes
- .gitignore +2 -0
- app.py +140 -0
- models.py +78 -0
.gitignore ADDED
@@ -0,0 +1,2 @@
/.idea
/__pycache__
app.py ADDED
@@ -0,0 +1,140 @@
import gradio as gr
import torch
from torchvision import transforms

from models import Cholec80Model

# device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# surgical phase labels
classes = {"Preparation": 0,
           "Calot Triangle Dissection": 1,
           "Clipping Cutting": 2,
           "Gallbladder Dissection": 3,
           "Gallbladder Packaging": 4,
           "Cleaning Coagulation": 5,
           "Gallbladder Retraction": 6}

# image transformations (dataset mean/std)
mean, std = [0.3456, 0.2281, 0.2233], [0.2528, 0.2135, 0.2104]
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize(mean=mean, std=std)])


def load_pretrained_params(model, model_state_path: str):
    pretrained_dict = torch.load(model_state_path, map_location="cpu")
    model_dict = model.state_dict()
    # 1. filter out unnecessary keys
    if list(pretrained_dict.keys())[0].startswith("module."):
        pretrained_dict = {k[7:]: v for k, v in pretrained_dict.items() if k[7:] in model_dict}
    else:
        pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
    # 2. overwrite entries in the existing state dict
    model_dict.update(pretrained_dict)
    # 3. load the new state dict
    model.load_state_dict(model_dict)
    # 4. eval mode
    model.eval()
    # 5. move the model to the device
    model.to(device)
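    # note: checkpoints saved from a model wrapped in nn.DataParallel store
    # every parameter key with a "module." prefix; the branch above strips
    # that prefix so the same weights load into the unwrapped module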

cnn_model = Cholec80Model({"image": [2048]})
load_pretrained_params(cnn_model, "checkpoints/cnn.ckpt")
pe_model = Cholec80Model({"image": [2048, 128], "pos_enc": [7, 7, 128]})
load_pretrained_params(pe_model, "checkpoints/cnn_pe_2.ckpt")

def cnn(image):
    # transform the frame and add a batch dimension
    input_tensor = transform(image)
    input_tensor = input_tensor.unsqueeze(dim=0).to(device)
    # predict
    with torch.no_grad():
        _, output_tensor = cnn_model(input_tensor, {})
    # probabilities of all classes
    pred_softmax = torch.softmax(output_tensor, dim=1).cpu().numpy()[0]
    # return label dict
    return {k: float(pred_softmax[v]) for k, v in classes.items()}

def cnn_mask(image, last_phase):
    # extract the phase index from the dropdown string, e.g. "0 - Preparation"
    last_phase = int(last_phase.split("-")[0].strip())
    # additive logit masks: row i keeps only the phases that may follow phase i
    masks = [
        [0, 0, -999, -999, -999, -999, -999],
        [-999, 0, 0, -999, -999, -999, -999],
        [-999, -999, 0, 0, -999, -999, -999],
        [-999, -999, -999, 0, 0, 0, -999],
        [-999, -999, -999, -999, 0, 0, 0],
        [-999, -999, -999, -999, 0, 0, 0],
        [-999, -999, -999, -999, -999, 0, 0]]
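    # (a 0 entry leaves the logit of an allowed phase unchanged, while adding
    # -999 drives a disallowed phase's softmax probability to effectively zero)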
    mask_tensor = torch.tensor([masks[last_phase]]).to(device)
    # transform the frame and add a batch dimension
    input_tensor = transform(image)
    input_tensor = input_tensor.unsqueeze(dim=0).to(device)
    # predict
    with torch.no_grad():
        _, output_tensor = cnn_model(input_tensor, {})
    # probabilities of all classes, with the transition mask added to the logits
    pred_softmax = torch.softmax(output_tensor + mask_tensor, dim=1).cpu().numpy()[0]
    # return label dict
    return {k: float(pred_softmax[v]) for k, v in classes.items()}

def cnn_pe(image, p_0, p_1, p_2, p_3, p_4, p_5, p_6):
    # form the position-encoding vector
    pos_enc = torch.Tensor([[p_0, p_1, p_2, p_3, p_4, p_5, p_6]]).to(device)
    # transform the frame and add a batch dimension
    input_tensor = transform(image)
    input_tensor = input_tensor.unsqueeze(dim=0).to(device)
    # predict
    with torch.no_grad():
        _, output_tensor = pe_model(input_tensor, {"pos_enc": pos_enc})
    # probabilities of all classes
    pred_softmax = torch.softmax(output_tensor, dim=1).cpu().numpy()[0]
    # return label dict
    return {k: float(pred_softmax[v]) for k, v in classes.items()}

with gr.Blocks() as demo:
    gr.Markdown("# Phase Recognition of Cholecystectomy Surgeries")
    with gr.Row():
        # input image
        image_input = gr.Image(shape=(255, 255), type="pil")
        # output label
        label_output = gr.Label()
    with gr.Tab("CNN") as cnn_tab:
        cnn_button = gr.Button("Predict")
        cnn_button.click(cnn, inputs=[image_input], outputs=[label_output])
    with gr.Tab("CNN+Mask") as mask_tab:
        phase = gr.Dropdown([f"{v} - {k}" for k, v in classes.items()], label="Last frame is of phase")
        mask_button = gr.Button("Predict")
        mask_button.click(cnn_mask, inputs=[image_input, phase], outputs=[label_output])
    with gr.Tab("CNN+PE") as pe_tab:
        with gr.Row():
            p0 = gr.Number(label="Phase 0")
            p1 = gr.Number(label="Phase 1")
            p2 = gr.Number(label="Phase 2")
            p3 = gr.Number(label="Phase 3")
            p4 = gr.Number(label="Phase 4")
            p5 = gr.Number(label="Phase 5")
            p6 = gr.Number(label="Phase 6")
        pe_button = gr.Button("Predict")
        pe_button.click(cnn_pe, inputs=[image_input, p0, p1, p2, p3, p4, p5, p6], outputs=[label_output])
    gr.Examples(
        examples=[['images/preparation.png'],
                  ['images/calot-triangle-dissection.png'],
                  ['images/clipping-cutting.png'],
                  ['images/gallbladder-dissection.png'],
                  ['images/gallbladder-packaging.png'],
                  ['images/cleaning-coagulation.png'],
                  ['images/gallbladder-retraction.png']],
        inputs=image_input
    )

if __name__ == "__main__":
    demo.launch(share=True)
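A quick way to sanity-check all three prediction paths outside the Gradio UI is a minimal sketch like the following; it assumes the checkpoints and the example images referenced by app.py are present locally:

# smoke-test sketch for the three prediction functions in app.py
from PIL import Image

frame = Image.open("images/preparation.png").convert("RGB")
print(cnn(frame))                                        # plain CNN
print(cnn_mask(frame, "0 - Preparation"))                # with transition mask
print(cnn_pe(frame, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0))  # with position encoding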
models.py ADDED
@@ -0,0 +1,78 @@
from collections import OrderedDict

import torch
import torchvision.models as models
from torch import nn


def get_linear_layers(dimensions):
    # build a list of named (Linear, ReLU) layer pairs between consecutive
    # dimensions, with no activation after the final Linear layer
    init_dim = dimensions[0]
    dimensions = dimensions[1:]
    if len(dimensions) < 1:
        return []
    layers = []
    tmp_dim = init_dim
    for i, d in enumerate(dimensions[:-1]):
        layers.append((f"linear{i + 1}", nn.Linear(tmp_dim, d)))
        layers.append((f"active{i + 1}", nn.ReLU()))
        tmp_dim = d
    layers.append((f"linear{len(dimensions)}", nn.Linear(tmp_dim, dimensions[-1])))
    return layers
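# e.g. get_linear_layers([2048, 128]) yields [("linear1", nn.Linear(2048, 128))],
# and get_linear_layers([7, 7, 128]) yields
# [("linear1", nn.Linear(7, 7)), ("active1", nn.ReLU()), ("linear2", nn.Linear(7, 128))]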

def num_flat_features(x):
    # number of features when flattening all dimensions except the batch dimension
    size = x.size()[1:]
    num_features = 1
    for s in size:
        num_features *= s
    return num_features

class Cholec80Model(nn.Module):
    def __init__(self, dimensions):
        super(Cholec80Model, self).__init__()
        # hyperparameters: maps each input name to its linear-layer dimensions
        self.dimensions = dimensions
        # CNN backbone for the image input
        if "image" in self.dimensions:
            self.model = models.resnet50(pretrained=True)
            self.model.fc = nn.Identity()
        # build one linear submodel per input
        self.submodels = {}
        for key in self.dimensions.keys():
            self.submodels[key] = nn.Sequential(OrderedDict(get_linear_layers(self.dimensions[key])))
        # register the submodels so their parameters are tracked by the module
        for key in self.submodels:
            self.add_module(key, self.submodels[key])
        # total feature dimension after concatenating all submodel outputs
        dim_concat = 0
        for key, ds in self.dimensions.items():
            out_dim = ds[-1]
            dim_concat += out_dim
        self.last_layer = nn.Sequential(
            nn.Linear(dim_concat, 7),
            nn.LogSigmoid()
        )
    def forward(self, img_tensor, info_tensors):
        concat_tensor = None
        # image feature extraction
        if "image" in self.dimensions:
            out_feature = self.model(img_tensor)
            concat_tensor = out_feature.clone()
            concat_tensor = self.submodels["image"](concat_tensor)
            concat_tensor = concat_tensor.view(-1, num_flat_features(concat_tensor))
        # concatenate the image features with the other info tensors
        for key, t in info_tensors.items():
            t = self.submodels[key](t)
            t = t.view(-1, num_flat_features(t))
            if concat_tensor is None:
                concat_tensor = t
            else:
                concat_tensor = torch.cat((concat_tensor, t), dim=1)
        # classification head
        out_tensor = self.last_layer(concat_tensor)
        # return the input alongside the per-phase scores
        return img_tensor, out_tensor
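A minimal forward-pass sketch with dummy tensors (no checkpoint required), mirroring the shapes app.py uses:

# instantiate the two-input variant and run one dummy frame through it
model = Cholec80Model({"image": [2048, 128], "pos_enc": [7, 7, 128]})
model.eval()
img = torch.randn(1, 3, 255, 255)  # one normalized RGB frame
pos = torch.zeros(1, 7)            # position-encoding vector, one slot per phase
with torch.no_grad():
    _, scores = model(img, {"pos_enc": pos})
print(scores.shape)  # torch.Size([1, 7]) -- one score per surgical phase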