Spaces:
Runtime error
Runtime error
import gradio as gr | |
import torch | |
from torchvision import transforms | |
# parameters | |
from models import Cholec80Model | |
# Run inference on the GPU when CUDA is available, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Surgical phases in model-output order; the value stored in `classes` is the
# column of the model's softmax output that belongs to that phase.
_PHASE_NAMES = (
    "Preparation",
    "Calot Triangle Dissection",
    "Clipping Cutting",
    "Gallbladder Dissection",
    "Gallbladder Packaging",
    "Cleaning Coagulation",
    "Gallbladder Retraction",
)
classes = {name: index for index, name in enumerate(_PHASE_NAMES)}

# Image preprocessing: convert the PIL image to a tensor, then standardise
# each of the three channels with the statistics below.
mean = [0.3456, 0.2281, 0.2233]
std = [0.2528, 0.2135, 0.2104]
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ]
)
# model loading
def load_pretrained_params(model, model_state_path: str, target_device=None):
    """Copy matching weights from a checkpoint into ``model`` and ready it for inference.

    The checkpoint is loaded onto the CPU, and its ``"state_dict"`` entry is
    merged into ``model.state_dict()``:

    * a leading ``"model."`` is stripped from every key that has it, so the
      key lines up with the names in ``model.state_dict()``;
    * keys that do not name a parameter/buffer of ``model`` are ignored;
    * parameters absent from the checkpoint keep their current values.

    Args:
        model: ``torch.nn.Module`` to populate; it is modified in place.
        model_state_path: path of the checkpoint file.
        target_device: device to move the model to. Defaults to the
            module-level ``device``.

    Returns:
        ``model`` itself, switched to eval mode and moved to the device.

    Warns:
        UserWarning: if the checkpoint's state dict is non-empty but none of
            its keys matched the model. In that case the model would otherwise
            silently keep its previous weights.
    """
    import warnings

    prefix = "model."
    # NOTE: torch.load unpickles the file -- only point this at trusted checkpoints.
    checkpoint = torch.load(model_state_path, map_location="cpu")
    pretrained_dict = checkpoint["state_dict"]
    model_dict = model.state_dict()

    # 1. Normalise names key by key rather than deciding from the first key
    #    only. This way an empty state dict does not raise, and a mix of
    #    prefixed and unprefixed keys is fully kept.
    matched = {}
    for key, value in pretrained_dict.items():
        name = key[len(prefix):] if key.startswith(prefix) else key
        if name in model_dict:
            matched[name] = value
    if pretrained_dict and not matched:
        warnings.warn(
            f"No entry of the state dict in {model_state_path!r} matches the model; "
            "its weights were left unchanged.",
            stacklevel=2,
        )

    # 2. Overwrite the matching entries and load the merged state dict back.
    model_dict.update(matched)
    model.load_state_dict(model_dict)

    # 3. Inference mode on the requested (or module-default) device.
    model.eval()
    model.to(device if target_device is None else target_device)
    return model
def _restore_model(config, checkpoint_path):
    """Build a ``Cholec80Model`` from *config* and load *checkpoint_path* into it."""
    network = Cholec80Model(config)
    load_pretrained_params(network, checkpoint_path)
    return network


# Image-only classifier (called with an empty extra-input dict by `cnn`/`cnn_mask`).
cnn_model = _restore_model({"image": [2048]}, "checkpoints/cnn.ckpt")
# Classifier that additionally receives a "pos_enc" input (used by `cnn_pe`).
pe_model = _restore_model(
    {"image": [2048, 128], "pos_enc": [7, 7, 128]},
    "checkpoints/cnn_pe_2.ckpt",
)
def cnn(label, image):
    """Predict phase probabilities for one frame with the image-only model.

    Args:
        label: ground-truth text shown in the UI. Gradio passes it in, but it
            is not used for the prediction.
        image: the frame as a PIL image (``gr.Image(type="pil")``), or
            ``None`` when nothing has been uploaded.

    Returns:
        dict mapping each phase name in ``classes`` to its softmax probability.

    Raises:
        gr.Error: if no image was provided. The UI then shows a readable
            message instead of a failure from inside the transform.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")
    # Preprocess and add a batch dimension of size 1.
    input_tensor = transform(image).unsqueeze(dim=0).to(device)
    with torch.no_grad():
        # The model returns a pair; the second element is fed to softmax as class scores.
        _, output_tensor = cnn_model(input_tensor, {})
    # Probabilities of all classes for the single image in the batch.
    pred_softmax = torch.softmax(output_tensor, dim=1).cpu().numpy()[0]
    return {name: float(pred_softmax[index]) for name, index in classes.items()}
# Additive masks for the model output, indexed by the phase of the last frame.
# Row i lists, per phase column, 0 (phase stays possible) or -999 (phase is
# pushed to ~zero probability by the following softmax).
_PHASE_MASKS = (
    (0, 0, -999, -999, -999, -999, -999),
    (-999, 0, 0, -999, -999, -999, -999),
    (-999, -999, 0, 0, -999, -999, -999),
    (-999, -999, -999, 0, 0, 0, -999),
    (-999, -999, -999, -999, 0, 0, 0),
    (-999, -999, -999, -999, 0, 0, 0),
    (-999, -999, -999, -999, -999, 0, 0),
)


def cnn_mask(label, image, last_phase):
    """Predict phase probabilities for one frame, restricted by the previous phase.

    The image-only ``cnn_model`` is run on the frame. Before the softmax, the
    row of ``_PHASE_MASKS`` selected by ``last_phase`` is added to its output,
    so that masked-out phases receive (near-)zero probability.

    Args:
        label: ground-truth text shown in the UI; not used for the prediction.
        image: the frame as a PIL image, or ``None`` if nothing was uploaded.
        last_phase: dropdown value of the form ``"<index> - <phase name>"``,
            or ``None``/empty when nothing has been selected.

    Returns:
        dict mapping each phase name in ``classes`` to its masked softmax probability.

    Raises:
        gr.Error: if the image or the last phase is missing, or if the phase
            index does not select a row of the mask table.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")
    if not last_phase:
        # The dropdown has no default, so an unselected value arrives as None.
        raise gr.Error("Please select the phase of the last frame.")
    # Dropdown entries look like "3 - Gallbladder Dissection"; keep the number.
    try:
        phase_index = int(last_phase.partition("-")[0].strip())
    except ValueError:
        raise gr.Error(f"Cannot read a phase number from {last_phase!r}.") from None
    if not 0 <= phase_index < len(_PHASE_MASKS):
        raise gr.Error(f"Unknown phase: {last_phase!r}.")

    # Preprocess and add a batch dimension of size 1.
    input_tensor = transform(image).unsqueeze(dim=0).to(device)
    with torch.no_grad():
        _, output_tensor = cnn_model(input_tensor, {})
    # Shape (1, 7), created directly on the output's device and dtype.
    mask_tensor = torch.tensor(
        [_PHASE_MASKS[phase_index]],
        dtype=output_tensor.dtype,
        device=output_tensor.device,
    )
    pred_softmax = torch.softmax(output_tensor + mask_tensor, dim=1).cpu().numpy()[0]
    return {name: float(pred_softmax[index]) for name, index in classes.items()}
def cnn_pe(label, image, p_0, p_1, p_2, p_3, p_4, p_5, p_6):
    """Predict phase probabilities with the model that also takes a position encoding.

    Args:
        label: ground-truth text shown in the UI; not used for the prediction.
        image: the frame as a PIL image, or ``None`` if nothing was uploaded.
        p_0, p_1, p_2, p_3, p_4, p_5, p_6: the seven values of the ``"pos_enc"``
            input, taken in order from the "Phase 0" ... "Phase 6" number
            fields. A cleared field arrives as ``None``.

    Returns:
        dict mapping each phase name in ``classes`` to its softmax probability.

    Raises:
        gr.Error: if the image or any of the seven values is missing.
    """
    values = [p_0, p_1, p_2, p_3, p_4, p_5, p_6]
    if image is None:
        raise gr.Error("Please upload an image first.")
    if any(value is None for value in values):
        raise gr.Error("Please enter a value for every phase (0-6).")
    # A batch of one encoding vector: shape (1, 7).
    pos_enc = torch.Tensor([values]).to(device)
    # Preprocess and add a batch dimension of size 1.
    input_tensor = transform(image).unsqueeze(dim=0).to(device)
    with torch.no_grad():
        _, output_tensor = pe_model(input_tensor, {"pos_enc": pos_enc})
    pred_softmax = torch.softmax(output_tensor, dim=1).cpu().numpy()[0]
    return {name: float(pred_softmax[index]) for name, index in classes.items()}
with gr.Blocks() as demo:
    gr.Markdown("# Phase Recognition of Cholecystectomy Surgeries")

    # Top row: the shared inputs in one column, with the result label beside it.
    with gr.Row():
        with gr.Column():
            groundtruth_label = gr.Text(label="Ground Truth", interactive=False)
            # NOTE(review): `shape=` is a Gradio 3.x argument of gr.Image (it was
            # dropped in 4.x) -- confirm the pinned gradio version before upgrading.
            image_input = gr.Image(shape=(255, 255), type="pil")
        label_output = gr.Label(label="Result")

    # Each tab runs one model variant; all of them write into `label_output`.
    with gr.Tab("CNN"):
        gr.Button("Predict").click(
            cnn,
            inputs=[groundtruth_label, image_input],
            outputs=[label_output],
        )

    with gr.Tab("CNN+Mask"):
        phase = gr.Dropdown(
            [f"{index} - {name}" for name, index in classes.items()],
            label="Last frame is of phase",
        )
        gr.Button("Predict").click(
            cnn_mask,
            inputs=[groundtruth_label, image_input, phase],
            outputs=[label_output],
        )

    with gr.Tab("CNN+PE"):
        with gr.Row():
            # One number field per phase: "Phase 0" ... "Phase 6".
            phase_numbers = [gr.Number(label=f"Phase {i}") for i in range(len(classes))]
        gr.Button("Predict").click(
            cnn_pe,
            inputs=[groundtruth_label, image_input, *phase_numbers],
            outputs=[label_output],
        )

    # Each example fills in the ground-truth text and the image.
    gr.Examples(
        examples=[
            ["0 - Preparation", "images/pe-sample.png"],
            ["1 - Calot Triangle Dissection", "images/mask-sample.png"],
            ["2 - Clipping Cutting", "images/clipping-cutting.png"],
            ["3 - Gallbladder Dissection", "images/gallbladder-dissection.png"],
            ["4 - Gallbladder Packaging", "images/gallbladder-packaging.png"],
            ["5 - Cleaning Coagulation", "images/cleaning-coagulation.png"],
            ["6 - Gallbladder Retraction", "images/gallbladder-retraction.png"],
        ],
        inputs=[groundtruth_label, image_input],
    )
def main():
    """Start the Gradio server that serves ``demo``."""
    demo.launch()


if __name__ == "__main__":
    main()