Hui committed on
Commit
abb3b71
1 Parent(s): 0b1f893

change samples

Browse files
Files changed (3) hide show
  1. app.py +18 -16
  2. images/mask-sample.png +0 -0
  3. images/pe-sample.png +0 -0
app.py CHANGED
@@ -47,7 +47,7 @@ pe_model = Cholec80Model({"image": [2048, 128], "pos_enc": [7, 7, 128]})
47
  load_pretrained_params(pe_model, "checkpoints/cnn_pe_2.ckpt")
48
 
49
 
50
- def cnn(image):
51
  # unsqueeze the input_tensor
52
  input_tensor = transform(image)
53
  input_tensor = input_tensor.unsqueeze(dim=0).to(device)
@@ -60,7 +60,7 @@ def cnn(image):
60
  return {k: float(pred_softmax[v]) for k, v in classes.items()}
61
 
62
 
63
- def cnn_mask(image, last_phase):
64
  # extract last phase
65
  last_phase = int(last_phase.split("-")[0].strip())
66
  # mask
@@ -85,7 +85,7 @@ def cnn_mask(image, last_phase):
85
  return {k: float(pred_softmax[v]) for k, v in classes.items()}
86
 
87
 
88
- def cnn_pe(image, p_0, p_1, p_2, p_3, p_4, p_5, p_6):
89
  # form the position encoder vector
90
  pos_enc = torch.Tensor([[p_0, p_1, p_2, p_3, p_4, p_5, p_6]]).to(device)
91
  # unsqueeze the input_tensor
@@ -103,16 +103,18 @@ with gr.Blocks() as demo:
103
  gr.Markdown("# Phase Recognition of Cholecystectomy Surgeries")
104
  # inputs
105
  with gr.Row():
106
- image_input = gr.Image(shape=(255, 255), type="pil")
 
 
107
  # output
108
- lable_output = gr.Label()
109
  with gr.Tab("CNN") as cnn_tab:
110
  cnn_button = gr.Button("Predict")
111
- cnn_button.click(cnn, inputs=[image_input], outputs=[lable_output])
112
  with gr.Tab("CNN+Mask") as mask_tab:
113
  phase = gr.Dropdown([f"{v} - {k}" for k, v in classes.items()], label="Last frame is of phase")
114
  mask_button = gr.Button("Predict")
115
- mask_button.click(cnn_mask, inputs=[image_input, phase], outputs=[lable_output])
116
  with gr.Tab("CNN+PE") as pe_tab:
117
  with gr.Row():
118
  p0 = gr.Number(label="Phase 0")
@@ -123,16 +125,16 @@ with gr.Blocks() as demo:
123
  p5 = gr.Number(label="Phase 5")
124
  p6 = gr.Number(label="Phase 6")
125
  pe_button = gr.Button("Predict")
126
- pe_button.click(cnn_pe, inputs=[image_input, p0, p1, p2, p3, p4, p5, p6], outputs=[lable_output])
127
  gr.Examples(
128
- examples=[['images/preparation.png'],
129
- ['images/calot-triangle-dissection.png'],
130
- ['images/clipping-cutting.png'],
131
- ['images/gallbladder-dissection.png'],
132
- ['images/gallbladder-packaging.png'],
133
- ['images/cleaning-coagulation.png'],
134
- ['images/gallbladder-retraction.png']],
135
- inputs=image_input
136
  )
137
 
138
  if __name__ == "__main__":
47
  load_pretrained_params(pe_model, "checkpoints/cnn_pe_2.ckpt")
48
 
49
 
50
+ def cnn(label, image):
51
  # unsqueeze the input_tensor
52
  input_tensor = transform(image)
53
  input_tensor = input_tensor.unsqueeze(dim=0).to(device)
60
  return {k: float(pred_softmax[v]) for k, v in classes.items()}
61
 
62
 
63
+ def cnn_mask(label, image, last_phase):
64
  # extract last phase
65
  last_phase = int(last_phase.split("-")[0].strip())
66
  # mask
85
  return {k: float(pred_softmax[v]) for k, v in classes.items()}
86
 
87
 
88
+ def cnn_pe(label, image, p_0, p_1, p_2, p_3, p_4, p_5, p_6):
89
  # form the position encoder vector
90
  pos_enc = torch.Tensor([[p_0, p_1, p_2, p_3, p_4, p_5, p_6]]).to(device)
91
  # unsqueeze the input_tensor
103
  gr.Markdown("# Phase Recognition of Cholecystectomy Surgeries")
104
  # inputs
105
  with gr.Row():
106
+ with gr.Column():
107
+ groundtruth_lable = gr.Text(label="Ground Truth", interactive=False)
108
+ image_input = gr.Image(shape=(255, 255), type="pil")
109
  # output
110
+ lable_output = gr.Label(label="Result")
111
  with gr.Tab("CNN") as cnn_tab:
112
  cnn_button = gr.Button("Predict")
113
+ cnn_button.click(cnn, inputs=[groundtruth_lable, image_input], outputs=[lable_output])
114
  with gr.Tab("CNN+Mask") as mask_tab:
115
  phase = gr.Dropdown([f"{v} - {k}" for k, v in classes.items()], label="Last frame is of phase")
116
  mask_button = gr.Button("Predict")
117
+ mask_button.click(cnn_mask, inputs=[groundtruth_lable, image_input, phase], outputs=[lable_output])
118
  with gr.Tab("CNN+PE") as pe_tab:
119
  with gr.Row():
120
  p0 = gr.Number(label="Phase 0")
125
  p5 = gr.Number(label="Phase 5")
126
  p6 = gr.Number(label="Phase 6")
127
  pe_button = gr.Button("Predict")
128
+ pe_button.click(cnn_pe, inputs=[groundtruth_lable, image_input, p0, p1, p2, p3, p4, p5, p6], outputs=[lable_output])
129
  gr.Examples(
130
+ examples=[['0 - Preparation', 'images/pe-sample.png'],
131
+ ['1 - Calot Triangle Dissection', 'images/mask-sample.png'],
132
+ ['2 - Clipping Cutting', 'images/clipping-cutting.png'],
133
+ ['3 - Gallbladder Dissection', 'images/gallbladder-dissection.png'],
134
+ ['4 - Gallbladder Packaging', 'images/gallbladder-packaging.png'],
135
+ ['5 - Cleaning Coagulation', 'images/cleaning-coagulation.png'],
136
+ ['6 - Gallbladder Retraction', 'images/gallbladder-retraction.png']],
137
+ inputs=[groundtruth_lable, image_input]
138
  )
139
 
140
  if __name__ == "__main__":
images/mask-sample.png ADDED
images/pe-sample.png ADDED