Spaces:
Runtime error
Runtime error
Hui
commited on
Commit
•
abb3b71
1
Parent(s):
0b1f893
change samples
Browse files- app.py +18 -16
- images/mask-sample.png +0 -0
- images/pe-sample.png +0 -0
app.py
CHANGED
@@ -47,7 +47,7 @@ pe_model = Cholec80Model({"image": [2048, 128], "pos_enc": [7, 7, 128]})
|
|
47 |
load_pretrained_params(pe_model, "checkpoints/cnn_pe_2.ckpt")
|
48 |
|
49 |
|
50 |
-
def cnn(image):
|
51 |
# unsqueeze the input_tensor
|
52 |
input_tensor = transform(image)
|
53 |
input_tensor = input_tensor.unsqueeze(dim=0).to(device)
|
@@ -60,7 +60,7 @@ def cnn(image):
|
|
60 |
return {k: float(pred_softmax[v]) for k, v in classes.items()}
|
61 |
|
62 |
|
63 |
-
def cnn_mask(image, last_phase):
|
64 |
# extract last phase
|
65 |
last_phase = int(last_phase.split("-")[0].strip())
|
66 |
# mask
|
@@ -85,7 +85,7 @@ def cnn_mask(image, last_phase):
|
|
85 |
return {k: float(pred_softmax[v]) for k, v in classes.items()}
|
86 |
|
87 |
|
88 |
-
def cnn_pe(image, p_0, p_1, p_2, p_3, p_4, p_5, p_6):
|
89 |
# form the position encoder vector
|
90 |
pos_enc = torch.Tensor([[p_0, p_1, p_2, p_3, p_4, p_5, p_6]]).to(device)
|
91 |
# unsqueeze the input_tensor
|
@@ -103,16 +103,18 @@ with gr.Blocks() as demo:
|
|
103 |
gr.Markdown("# Phase Recognition of Cholecystectomy Surgeries")
|
104 |
# inputs
|
105 |
with gr.Row():
|
106 |
-
|
|
|
|
|
107 |
# output
|
108 |
-
lable_output = gr.Label()
|
109 |
with gr.Tab("CNN") as cnn_tab:
|
110 |
cnn_button = gr.Button("Predict")
|
111 |
-
cnn_button.click(cnn, inputs=[image_input], outputs=[lable_output])
|
112 |
with gr.Tab("CNN+Mask") as mask_tab:
|
113 |
phase = gr.Dropdown([f"{v} - {k}" for k, v in classes.items()], label="Last frame is of phase")
|
114 |
mask_button = gr.Button("Predict")
|
115 |
-
mask_button.click(cnn_mask, inputs=[image_input, phase], outputs=[lable_output])
|
116 |
with gr.Tab("CNN+PE") as pe_tab:
|
117 |
with gr.Row():
|
118 |
p0 = gr.Number(label="Phase 0")
|
@@ -123,16 +125,16 @@ with gr.Blocks() as demo:
|
|
123 |
p5 = gr.Number(label="Phase 5")
|
124 |
p6 = gr.Number(label="Phase 6")
|
125 |
pe_button = gr.Button("Predict")
|
126 |
-
pe_button.click(cnn_pe, inputs=[image_input, p0, p1, p2, p3, p4, p5, p6], outputs=[lable_output])
|
127 |
gr.Examples(
|
128 |
-
examples=[['images/
|
129 |
-
['images/
|
130 |
-
['images/clipping-cutting.png'],
|
131 |
-
['images/gallbladder-dissection.png'],
|
132 |
-
['images/gallbladder-packaging.png'],
|
133 |
-
['images/cleaning-coagulation.png'],
|
134 |
-
['images/gallbladder-retraction.png']],
|
135 |
-
inputs=image_input
|
136 |
)
|
137 |
|
138 |
if __name__ == "__main__":
|
47 |
load_pretrained_params(pe_model, "checkpoints/cnn_pe_2.ckpt")
|
48 |
|
49 |
|
50 |
+
def cnn(label, image):
|
51 |
# unsqueeze the input_tensor
|
52 |
input_tensor = transform(image)
|
53 |
input_tensor = input_tensor.unsqueeze(dim=0).to(device)
|
60 |
return {k: float(pred_softmax[v]) for k, v in classes.items()}
|
61 |
|
62 |
|
63 |
+
def cnn_mask(label, image, last_phase):
|
64 |
# extract last phase
|
65 |
last_phase = int(last_phase.split("-")[0].strip())
|
66 |
# mask
|
85 |
return {k: float(pred_softmax[v]) for k, v in classes.items()}
|
86 |
|
87 |
|
88 |
+
def cnn_pe(label, image, p_0, p_1, p_2, p_3, p_4, p_5, p_6):
|
89 |
# form the position encoder vector
|
90 |
pos_enc = torch.Tensor([[p_0, p_1, p_2, p_3, p_4, p_5, p_6]]).to(device)
|
91 |
# unsqueeze the input_tensor
|
103 |
gr.Markdown("# Phase Recognition of Cholecystectomy Surgeries")
|
104 |
# inputs
|
105 |
with gr.Row():
|
106 |
+
with gr.Column():
|
107 |
+
groundtruth_lable = gr.Text(label="Ground Truth", interactive=False)
|
108 |
+
image_input = gr.Image(shape=(255, 255), type="pil")
|
109 |
# output
|
110 |
+
lable_output = gr.Label(label="Result")
|
111 |
with gr.Tab("CNN") as cnn_tab:
|
112 |
cnn_button = gr.Button("Predict")
|
113 |
+
cnn_button.click(cnn, inputs=[groundtruth_lable, image_input], outputs=[lable_output])
|
114 |
with gr.Tab("CNN+Mask") as mask_tab:
|
115 |
phase = gr.Dropdown([f"{v} - {k}" for k, v in classes.items()], label="Last frame is of phase")
|
116 |
mask_button = gr.Button("Predict")
|
117 |
+
mask_button.click(cnn_mask, inputs=[groundtruth_lable, image_input, phase], outputs=[lable_output])
|
118 |
with gr.Tab("CNN+PE") as pe_tab:
|
119 |
with gr.Row():
|
120 |
p0 = gr.Number(label="Phase 0")
|
125 |
p5 = gr.Number(label="Phase 5")
|
126 |
p6 = gr.Number(label="Phase 6")
|
127 |
pe_button = gr.Button("Predict")
|
128 |
+
pe_button.click(cnn_pe, inputs=[groundtruth_lable, image_input, p0, p1, p2, p3, p4, p5, p6], outputs=[lable_output])
|
129 |
gr.Examples(
|
130 |
+
examples=[['0 - Preparation', 'images/pe-sample.png'],
|
131 |
+
['1 - Calot Triangle Dissection', 'images/mask-sample.png'],
|
132 |
+
['2 - Clipping Cutting', 'images/clipping-cutting.png'],
|
133 |
+
['3 - Gallbladder Dissection', 'images/gallbladder-dissection.png'],
|
134 |
+
['4 - Gallbladder Packaging', 'images/gallbladder-packaging.png'],
|
135 |
+
['5 - Cleaning Coagulation', 'images/cleaning-coagulation.png'],
|
136 |
+
['6 - Gallbladder Retraction', 'images/gallbladder-retraction.png']],
|
137 |
+
inputs=[groundtruth_lable, image_input]
|
138 |
)
|
139 |
|
140 |
if __name__ == "__main__":
|
images/mask-sample.png
ADDED
images/pe-sample.png
ADDED