Hui
change samples
abb3b71
import gradio as gr
import torch
from torchvision import transforms
# parameters
from models import Cholec80Model
# Compute device: the CUDA GPU when torch can see one, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Surgical phase name -> index of that phase in the model's output vector.
# Indices follow the order the names are listed in below (0..6).
classes = {
    phase_name: phase_index
    for phase_index, phase_name in enumerate(
        (
            "Preparation",
            "Calot Triangle Dissection",
            "Clipping Cutting",
            "Gallbladder Dissection",
            "Gallbladder Packaging",
            "Cleaning Coagulation",
            "Gallbladder Retraction",
        )
    )
}
# Image preprocessing.
# Per-channel statistics used for normalization (presumably computed on the
# training frames -- TODO confirm their origin).
mean = [0.3456, 0.2281, 0.2233]
std = [0.2528, 0.2135, 0.2104]

# PIL image -> float tensor (ToTensor) -> channel-wise standardized tensor.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ]
)
# model loading
def load_pretrained_params(model, model_state_path: str, target_device=None, prefix: str = "model."):
    """Load matching weights from a checkpoint into *model* and prepare it for inference.

    The checkpoint must be a mapping with a ``"state_dict"`` entry (presumably a
    PyTorch Lightning checkpoint -- TODO confirm). If the state dict's first key
    starts with *prefix*, only keys carrying that prefix are used and the prefix
    is stripped. Otherwise keys are used as-is. Either way, keys the model does
    not have are ignored. Model entries missing from the checkpoint keep their
    current values.

    Args:
        model: ``torch.nn.Module`` to load into; modified in place.
        model_state_path: path to the checkpoint file.
        target_device: device to move the model to; when ``None`` the
            module-level ``device`` is used.
        prefix: key prefix to strip from prefixed checkpoints.

    Returns:
        The same *model*, in eval mode and on the target device.
    """
    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    checkpoint = torch.load(model_state_path, map_location="cpu")
    pretrained_dict = checkpoint["state_dict"]
    model_dict = model.state_dict()

    # 1. Keep only the keys the model knows about, stripping the prefix if present.
    #    next(iter(...), "") avoids an IndexError on an empty state dict.
    first_key = next(iter(pretrained_dict), "")
    if prefix and first_key.startswith(prefix):
        cut = len(prefix)
        # Strip only keys that really carry the prefix. Slicing a fixed number
        # of characters off every key could map an unrelated key
        # (e.g. "extra.weight" -> "weight") onto a real parameter.
        pretrained_dict = {
            key[cut:]: value
            for key, value in pretrained_dict.items()
            if key.startswith(prefix) and key[cut:] in model_dict
        }
    else:
        pretrained_dict = {key: value for key, value in pretrained_dict.items() if key in model_dict}

    # 2. Overwrite the matching entries of the model's current state dict.
    model_dict.update(pretrained_dict)
    # 3. Load the merged state dict back into the model.
    model.load_state_dict(model_dict)
    # 4. Switch to inference mode.
    model.eval()
    # 5. Move to the requested device (module-level default when none is given).
    model.to(device if target_device is None else target_device)
    return model
def _build_model(config, checkpoint_path):
    """Construct a Cholec80Model from *config* and load *checkpoint_path* into it."""
    net = Cholec80Model(config)
    load_pretrained_params(net, checkpoint_path)
    return net


# Image-only CNN.
# NOTE(review): the config dicts presumably give per-input layer sizes -- confirm
# against Cholec80Model.
cnn_model = _build_model({"image": [2048]}, "checkpoints/cnn.ckpt")
# CNN with an additional "pos_enc" input branch.
pe_model = _build_model({"image": [2048, 128], "pos_enc": [7, 7, 128]}, "checkpoints/cnn_pe_2.ckpt")
def cnn(label, image):
    """Classify *image* with the plain CNN and return per-phase probabilities.

    Args:
        label: ground-truth text from the UI; not used for the prediction.
        image: input frame (PIL image from the Gradio image component).

    Returns:
        dict mapping each phase name to its softmax probability.
    """
    # Add a leading batch dimension of size 1 and move to the model's device.
    batch = transform(image).unsqueeze(dim=0).to(device)
    with torch.no_grad():
        # The model returns a pair; only the second element (presumably the
        # class logits -- softmax is applied below) is used.
        _features, logits = cnn_model(batch, {})
    probabilities = torch.softmax(logits, dim=1).cpu().numpy()[0]
    return {name: float(probabilities[index]) for name, index in classes.items()}
# Transition masks used by cnn_mask. Row i is added to the logits when the last
# frame was phase i: 0 leaves a phase's logit unchanged, while -999 pushes its
# probability to practically zero after the softmax. Built once at import time
# instead of on every call.
_PHASE_TRANSITION_MASKS = (
    (0, 0, -999, -999, -999, -999, -999),
    (-999, 0, 0, -999, -999, -999, -999),
    (-999, -999, 0, 0, -999, -999, -999),
    (-999, -999, -999, 0, 0, 0, -999),
    (-999, -999, -999, -999, 0, 0, 0),
    (-999, -999, -999, -999, 0, 0, 0),
    (-999, -999, -999, -999, -999, 0, 0),
)


def cnn_mask(label, image, last_phase):
    """Classify *image* with the plain CNN, restricted to phases allowed after *last_phase*.

    Args:
        label: ground-truth text from the UI; not used for the prediction.
        image: input frame (PIL image from the Gradio image component).
        last_phase: phase of the previous frame, either a dropdown entry such
            as ``"3 - Gallbladder Dissection"`` or a plain phase index.

    Returns:
        dict mapping each phase name to its masked softmax probability.

    Raises:
        ValueError: if *last_phase* is missing, unparsable or out of range.
    """
    # Validate first so an unselected dropdown gives a clear message instead of
    # an AttributeError on None.
    if last_phase is None or (isinstance(last_phase, str) and not last_phase.strip()):
        raise ValueError("Please select the phase of the last frame.")
    if isinstance(last_phase, str):
        # Dropdown entries look like "<index> - <name>"; keep the leading index.
        phase_idx = int(last_phase.split("-")[0].strip())
    else:
        phase_idx = int(last_phase)
    if not 0 <= phase_idx < len(_PHASE_TRANSITION_MASKS):
        raise ValueError(
            f"last_phase must be between 0 and {len(_PHASE_TRANSITION_MASKS) - 1}, got {phase_idx}"
        )

    # Add a leading batch dimension of size 1 and move to the model's device.
    input_tensor = transform(image).unsqueeze(dim=0).to(device)
    with torch.no_grad():
        _, output_tensor = cnn_model(input_tensor, {})
    # One mask row of shape (1, 7), on the logits' device and dtype.
    mask_tensor = torch.tensor(
        [_PHASE_TRANSITION_MASKS[phase_idx]],
        dtype=output_tensor.dtype,
        device=output_tensor.device,
    )
    pred_softmax = torch.softmax(output_tensor + mask_tensor, dim=1).cpu().numpy()[0]
    return {k: float(pred_softmax[v]) for k, v in classes.items()}
def cnn_pe(label, image, p_0, p_1, p_2, p_3, p_4, p_5, p_6):
    """Classify *image* with the CNN + positional-encoding model.

    Args:
        label: ground-truth text from the UI; not used for the prediction.
        image: input frame (PIL image from the Gradio image component).
        p_0 ... p_6: one number per phase, passed to the model as its
            ``"pos_enc"`` input (meaning of the values: TODO confirm).

    Returns:
        dict mapping each phase name to its softmax probability.
    """
    encoding = [p_0, p_1, p_2, p_3, p_4, p_5, p_6]
    # Shape (1, 7): a batch holding one encoding vector.
    pos_enc = torch.Tensor([encoding]).to(device)
    # Add a leading batch dimension of size 1 and move to the model's device.
    batch = transform(image).unsqueeze(dim=0).to(device)
    with torch.no_grad():
        _features, logits = pe_model(batch, {"pos_enc": pos_enc})
    probabilities = torch.softmax(logits, dim=1).cpu().numpy()[0]
    return {name: float(probabilities[index]) for name, index in classes.items()}
# Gradio UI: shared inputs and output, one tab per model variant, and example
# frames paired with their ground-truth phase.
# NOTE(review): the original indentation of this section was lost. The nesting of
# the layout blocks below (Row/Column/Tab) is a reconstruction -- confirm the layout.
# NOTE(review): "lable" is a typo for "label"; kept because these are module-level names.
with gr.Blocks() as demo:
    gr.Markdown("# Phase Recognition of Cholecystectomy Surgeries")
    # inputs
    with gr.Row():
        with gr.Column():
            # Filled in by the examples below; not editable by the user.
            groundtruth_lable = gr.Text(label="Ground Truth", interactive=False)
            image_input = gr.Image(shape=(255, 255), type="pil")
        # output (shared by every tab)
        lable_output = gr.Label(label="Result")
    # Each tab wires its own Predict button to one prediction function; all of
    # them write to the same Result label.
    with gr.Tab("CNN") as cnn_tab:
        cnn_button = gr.Button("Predict")
        cnn_button.click(cnn, inputs=[groundtruth_lable, image_input], outputs=[lable_output])
    with gr.Tab("CNN+Mask") as mask_tab:
        # Choices look like "3 - Gallbladder Dissection"; cnn_mask parses the leading number.
        phase = gr.Dropdown([f"{v} - {k}" for k, v in classes.items()], label="Last frame is of phase")
        mask_button = gr.Button("Predict")
        mask_button.click(cnn_mask, inputs=[groundtruth_lable, image_input, phase], outputs=[lable_output])
    with gr.Tab("CNN+PE") as pe_tab:
        # One numeric input per phase, forwarded to cnn_pe as p_0 ... p_6.
        with gr.Row():
            p0 = gr.Number(label="Phase 0")
            p1 = gr.Number(label="Phase 1")
            p2 = gr.Number(label="Phase 2")
            p3 = gr.Number(label="Phase 3")
            p4 = gr.Number(label="Phase 4")
            p5 = gr.Number(label="Phase 5")
            p6 = gr.Number(label="Phase 6")
        pe_button = gr.Button("Predict")
        pe_button.click(cnn_pe, inputs=[groundtruth_lable, image_input, p0, p1, p2, p3, p4, p5, p6], outputs=[lable_output])
    # Example frames with their ground-truth phase label.
    # NOTE(review): the first two file names (pe-sample, mask-sample) do not name
    # a phase -- confirm they really show phases 0 and 1.
    gr.Examples(
        examples=[['0 - Preparation', 'images/pe-sample.png'],
                  ['1 - Calot Triangle Dissection', 'images/mask-sample.png'],
                  ['2 - Clipping Cutting', 'images/clipping-cutting.png'],
                  ['3 - Gallbladder Dissection', 'images/gallbladder-dissection.png'],
                  ['4 - Gallbladder Packaging', 'images/gallbladder-packaging.png'],
                  ['5 - Cleaning Coagulation', 'images/cleaning-coagulation.png'],
                  ['6 - Gallbladder Retraction', 'images/gallbladder-retraction.png']],
        inputs=[groundtruth_lable, image_input]
    )


# Start the web server only when run as a script, not on import.
if __name__ == "__main__":
    demo.launch()