Hui commited on
Commit
ea774f6
1 Parent(s): 6b2ab77
Files changed (3) hide show
  1. .gitignore +2 -0
  2. app.py +140 -0
  3. models.py +78 -0
.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
1
+ /.idea
2
+ /__pycache__
app.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from PIL import Image
4
+ from torch import nn
5
+ from torchvision import models, transforms
6
+
7
+ # parameters
8
+ from models import Cholec80Model
9
+
10
+
11
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
12
+ classes = {"Preparation": 0,
13
+ "Calot Triangle Dissection": 1,
14
+ "Clipping Cutting": 2,
15
+ "Gallbladder Dissection": 3,
16
+ "Gallbladder Packaging": 4,
17
+ "Cleaning Coagulation": 5,
18
+ "Gallbladder Retraction": 6}
19
+
20
+ # image transformations
21
+ mean, std = [0.3456, 0.2281, 0.2233], [0.2528, 0.2135, 0.2104]
22
+ transform = transforms.Compose([transforms.ToTensor(),
23
+ transforms.Normalize(mean=mean, std=std)])
24
+
25
+
26
+ # model imports
27
+ def load_pretrained_params(model, model_state_path: str):
28
+ pretrained_dict = torch.load(model_state_path, map_location="cpu")
29
+ model_dict = model.state_dict()
30
+ # 1. filter out unnecessary keys
31
+ if list(pretrained_dict.keys())[0].startswith("module."):
32
+ pretrained_dict = {k[7:]: v for k, v in pretrained_dict.items() if k[7:] in model_dict}
33
+ else:
34
+ pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
35
+ # 2. overwrite entries in the existing state dict
36
+ model_dict.update(pretrained_dict)
37
+ # 3. load the new state dict
38
+ model.load_state_dict(model_dict)
39
+ # 4. eval mode
40
+ model.eval()
41
+ # 5. put model to device
42
+ model.to(device)
43
+
44
+
45
+ cnn_model = Cholec80Model({"image": [2048]})
46
+ load_pretrained_params(cnn_model, "checkpoints/cnn.ckpt")
47
+ pe_model = Cholec80Model({"image": [2048, 128], "pos_enc": [7, 7, 128]})
48
+ load_pretrained_params(pe_model, "checkpoints/cnn_pe_2.ckpt")
49
+
50
+
51
+ def cnn(image):
52
+ # unsqueeze the input_tensor
53
+ input_tensor = transform(image)
54
+ input_tensor = input_tensor.unsqueeze(dim=0).to(device)
55
+ # predict
56
+ with torch.no_grad():
57
+ _, output_tensor = cnn_model(input_tensor, {})
58
+ # probabilities of all classes
59
+ pred_softmax = torch.softmax(output_tensor, dim=1).cpu().numpy()[0]
60
+ # return label dict
61
+ return {k: float(pred_softmax[v]) for k, v in classes.items()}
62
+
63
+
64
+ def cnn_mask(image, last_phase):
65
+ # extract last phase
66
+ last_phase = int(last_phase.split("-")[0].strip())
67
+ # mask
68
+ masks = [
69
+ [0, 0, -999, -999, -999, -999, -999],
70
+ [-999, 0, 0, -999, -999, -999, -999],
71
+ [-999, -999, 0, 0, -999, -999, -999],
72
+ [-999, -999, -999, 0, 0, 0, -999],
73
+ [-999, -999, -999, -999, 0, 0, 0],
74
+ [-999, -999, -999, -999, 0, 0, 0],
75
+ [-999, -999, -999, -999, -999, 0, 0]]
76
+ mask_tensor = torch.tensor([masks[last_phase]]).to(device)
77
+ # unsqueeze the input_tensor
78
+ input_tensor = transform(image)
79
+ input_tensor = input_tensor.unsqueeze(dim=0).to(device)
80
+ # predict
81
+ with torch.no_grad():
82
+ _, output_tensor = cnn_model(input_tensor, {})
83
+ # probabilities of all classes
84
+ pred_softmax = torch.softmax(output_tensor + mask_tensor, dim=1).cpu().numpy()[0]
85
+ # return label dict
86
+ return {k: float(pred_softmax[v]) for k, v in classes.items()}
87
+
88
+
89
+ def cnn_pe(image, p_0, p_1, p_2, p_3, p_4, p_5, p_6):
90
+ # form the position encoder vector
91
+ pos_enc = torch.Tensor([[p_0, p_1, p_2, p_3, p_4, p_5, p_6]]).to(device)
92
+ # unsqueeze the input_tensor
93
+ input_tensor = transform(image)
94
+ input_tensor = input_tensor.unsqueeze(dim=0).to(device)
95
+ # predict
96
+ with torch.no_grad():
97
+ _, output_tensor = pe_model(input_tensor, {"pos_enc": pos_enc})
98
+ pred_softmax = torch.softmax(output_tensor, dim=1).cpu().numpy()[0]
99
+ # return label dict
100
+ return {k: float(pred_softmax[v]) for k, v in classes.items()}
101
+
102
+
103
+ with gr.Blocks() as demo:
104
+ gr.Markdown("# Phase Recognition of Cholecystectomy Surgeries")
105
+ # inputs
106
+ with gr.Row():
107
+ image_input = gr.Image(shape=(255, 255), type="pil")
108
+ # output
109
+ lable_output = gr.Label()
110
+ with gr.Tab("CNN") as cnn_tab:
111
+ cnn_button = gr.Button("Predict")
112
+ cnn_button.click(cnn, inputs=[image_input], outputs=[lable_output])
113
+ with gr.Tab("CNN+Mask") as mask_tab:
114
+ phase = gr.Dropdown([f"{v} - {k}" for k, v in classes.items()], label="Last frame is of phase")
115
+ mask_button = gr.Button("Predict")
116
+ mask_button.click(cnn_mask, inputs=[image_input, phase], outputs=[lable_output])
117
+ with gr.Tab("CNN+PE") as pe_tab:
118
+ with gr.Row():
119
+ p0 = gr.Number(label="Phase 0")
120
+ p1 = gr.Number(label="Phase 1")
121
+ p2 = gr.Number(label="Phase 2")
122
+ p3 = gr.Number(label="Phase 3")
123
+ p4 = gr.Number(label="Phase 4")
124
+ p5 = gr.Number(label="Phase 5")
125
+ p6 = gr.Number(label="Phase 6")
126
+ pe_button = gr.Button("Predict")
127
+ pe_button.click(cnn_pe, inputs=[image_input, p0, p1, p2, p3, p4, p5, p6], outputs=[lable_output])
128
+ gr.Examples(
129
+ examples=[['images/preparation.png'],
130
+ ['images/calot-triangle-dissection.png'],
131
+ ['images/clipping-cutting.png'],
132
+ ['images/gallbladder-dissection.png'],
133
+ ['images/gallbladder-packaging.png'],
134
+ ['images/cleaning-coagulation.png'],
135
+ ['images/gallbladder-retraction.png']],
136
+ inputs=image_input
137
+ )
138
+
139
+ if __name__ == "__main__":
140
+ demo.launch(share=True)
models.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+
3
+ import torch
4
+ import torchvision.models as models
5
+ from torch import nn
6
+ from collections import OrderedDict
7
+
8
+
9
+ def get_linear_layers(dimensions):
10
+ init_dim = dimensions[0]
11
+ dimensions = dimensions[1:]
12
+ if len(dimensions) < 1:
13
+ return []
14
+ layers = []
15
+ tmp_dim = init_dim
16
+ for i, d in enumerate(dimensions[:-1]):
17
+ layers.append((f"linear{i + 1}", nn.Linear(tmp_dim, d)))
18
+ layers.append((f"active{i + 1}", nn.ReLU()))
19
+ tmp_dim = d
20
+ layers.append((f"linear{len(dimensions)}", nn.Linear(tmp_dim, dimensions[-1])))
21
+ return layers
22
+
23
+
24
+ def num_flat_features(x):
25
+ size = x.size()[1:]
26
+ num_features = 1
27
+ for s in size:
28
+ num_features *= s
29
+ return num_features
30
+
31
+
32
+ class Cholec80Model(nn.Module):
33
+ def __init__(self, dimensions):
34
+ super(Cholec80Model, self).__init__()
35
+ # hyperparams
36
+ self.dimensions = dimensions
37
+ # CNN models
38
+ if "image" in self.dimensions:
39
+ self.model = models.resnet50(pretrained=True)
40
+ self.model.fc = nn.Identity()
41
+ # get img submodel
42
+ self.submodels = {}
43
+ # get info submodels
44
+ for key in self.dimensions.keys():
45
+ self.submodels[key] = nn.Sequential(OrderedDict(get_linear_layers(self.dimensions[key])))
46
+ # !!!register submodels to model
47
+ for key in self.submodels:
48
+ self.add_module(key, self.submodels[key])
49
+ # concat layers
50
+ dim_concat = 0
51
+ for key, ds in self.dimensions.items():
52
+ out_dim = ds[-1]
53
+ dim_concat += out_dim
54
+ self.last_layer = nn.Sequential(
55
+ nn.Linear(dim_concat, 7),
56
+ nn.LogSigmoid()
57
+ )
58
+
59
+ def forward(self, img_tensor, info_tensors):
60
+ concat_tensor = None
61
+ # image feature extraction
62
+ if "image" in self.dimensions:
63
+ out_feature = self.model(img_tensor)
64
+ concat_tensor = out_feature.clone()
65
+ concat_tensor = self.submodels["image"](concat_tensor)
66
+ concat_tensor = concat_tensor.view(-1, num_flat_features(concat_tensor))
67
+ # concat image_tensor with other info_tensors
68
+ for key, t in info_tensors.items():
69
+ t = self.submodels[key](t)
70
+ t = t.view(-1, num_flat_features(t))
71
+ if concat_tensor is None:
72
+ concat_tensor = t
73
+ else:
74
+ concat_tensor = torch.cat((concat_tensor, t), dim=1)
75
+ # last_layer
76
+ out_tensor = self.last_layer(concat_tensor)
77
+ # return results
78
+ return img_tensor, out_tensor