import warnings

import gradio as gr
import torch
from torchvision import transforms

# project-local model definition
from models import Cholec80Model

# run inference on the GPU when one is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# phase name -> index of that phase in the model's output vector
classes = {"Preparation": 0, "Calot Triangle Dissection": 1, "Clipping Cutting": 2,
           "Gallbladder Dissection": 3, "Gallbladder Packaging": 4,
           "Cleaning Coagulation": 5, "Gallbladder Retraction": 6}

# image transformations: PIL image -> tensor, normalised with these per-channel mean/std values
mean, std = [0.3456, 0.2281, 0.2233], [0.2528, 0.2135, 0.2104]
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=mean, std=std)])

# Additive logit masks for the "CNN+Mask" tab; row i is used when the previous frame
# was of phase i. 0 keeps a phase possible, -999 pushes its softmax probability to ~0.
PHASE_TRANSITION_MASKS = [
    [0, 0, -999, -999, -999, -999, -999],
    [-999, 0, 0, -999, -999, -999, -999],
    [-999, -999, 0, 0, -999, -999, -999],
    [-999, -999, -999, 0, 0, 0, -999],
    [-999, -999, -999, -999, 0, 0, 0],
    [-999, -999, -999, -999, 0, 0, 0],
    [-999, -999, -999, -999, -999, 0, 0],
]


def load_pretrained_params(model, model_state_path: str):
    """Load checkpoint weights into ``model`` in place and prepare it for inference.

    Args:
        model: the ``torch.nn.Module`` to initialise.
        model_state_path: path of a file readable by ``torch.load``. It may hold a
            plain state dict or a dict that stores the weights under ``"state_dict"``.

    Only checkpoint entries whose names (after removing a leading ``"module."``
    prefix, as written e.g. by ``torch.nn.DataParallel``) match a key of
    ``model.state_dict()`` are copied; every other model parameter keeps its
    current value. A warning is emitted when no entry matches at all. Finally the
    model is switched to eval mode and moved to ``device``.
    """
    # NOTE: torch.load unpickles the file -- only point this at trusted checkpoints.
    checkpoint = torch.load(model_state_path, map_location="cpu")
    # unwrap checkpoints that keep the weights next to other training metadata
    nested = checkpoint.get("state_dict") if isinstance(checkpoint, dict) else None
    pretrained_dict = nested if isinstance(nested, dict) else checkpoint
    model_dict = model.state_dict()

    # 1. filter out unnecessary keys; the prefix is stripped key by key, so an empty
    #    checkpoint or a mix of prefixed/unprefixed names is handled correctly
    prefix = "module."
    filtered = {}
    for key, value in pretrained_dict.items():
        name = key[len(prefix):] if key.startswith(prefix) else key
        if name in model_dict:
            filtered[name] = value
    if not filtered:
        # otherwise the demo would silently run with untrained weights
        warnings.warn(f"No parameter in '{model_state_path}' matches the model; "
                      f"it keeps its initial weights.")
    # 2. overwrite entries in the existing state dict
    model_dict.update(filtered)
    # 3. load the new state dict
    model.load_state_dict(model_dict)
    # 4. eval mode
    model.eval()
    # 5. put model to device
    model.to(device)


# models: plain CNN, and CNN with an extra "pos_enc" input
cnn_model = Cholec80Model({"image": [2048]})
load_pretrained_params(cnn_model, "checkpoints/cnn.ckpt")
pe_model = Cholec80Model({"image": [2048, 128], "pos_enc": [7, 7, 128]})
load_pretrained_params(pe_model, "checkpoints/cnn_pe_2.ckpt")


def _predict(model, image, extra_inputs, logit_offset=None):
    """Run ``model`` on a single PIL image and return ``{phase name: probability}``.

    ``extra_inputs`` is passed as the model's second argument. ``logit_offset``,
    if given, is added to the model output before the softmax.

    Raises:
        ValueError: if no image was provided.
    """
    if image is None:
        raise ValueError("Please provide an input image.")
    # add a batch dimension: the model is called on a batch holding one image
    input_tensor = transform(image).unsqueeze(dim=0).to(device)
    with torch.no_grad():
        # only the second element of the model's output pair is used (softmaxed as scores)
        _, output_tensor = model(input_tensor, extra_inputs)
        if logit_offset is not None:
            output_tensor = output_tensor + logit_offset
        # probabilities of all classes for the single batch element
        pred_softmax = torch.softmax(output_tensor, dim=1).cpu().numpy()[0]
    # return label dict
    return {k: float(pred_softmax[v]) for k, v in classes.items()}


def cnn(image):
    """Predict phase probabilities for ``image`` with the plain CNN model."""
    return _predict(cnn_model, image, {})


def cnn_mask(image, last_phase):
    """Predict phase probabilities, restricted to phases reachable from ``last_phase``.

    ``last_phase`` is a dropdown value of the form ``"<index> - <phase name>"``.
    When no phase is selected (``None`` or empty) no mask is applied and the result
    equals :func:`cnn`.
    """
    if last_phase is None or not str(last_phase).strip():
        return _predict(cnn_model, image, {})
    # extract the phase index from "<index> - <name>"
    phase_idx = int(str(last_phase).split("-")[0].strip())
    mask_tensor = torch.tensor([PHASE_TRANSITION_MASKS[phase_idx]],
                               dtype=torch.float32, device=device)
    return _predict(cnn_model, image, {}, logit_offset=mask_tensor)


def cnn_pe(image, p_0, p_1, p_2, p_3, p_4, p_5, p_6):
    """Predict phase probabilities with the model that also takes a ``pos_enc`` vector.

    ``p_0`` … ``p_6`` are the seven ``pos_enc`` values (UI fields "Phase 0" …
    "Phase 6"); empty fields (``None``) are treated as 0.
    """
    values = [0.0 if p is None else float(p) for p in (p_0, p_1, p_2, p_3, p_4, p_5, p_6)]
    # form the position encoder vector (batch of one)
    pos_enc = torch.tensor([values], dtype=torch.float32, device=device)
    return _predict(pe_model, image, {"pos_enc": pos_enc})


with gr.Blocks() as demo:
    gr.Markdown("# Phase Recognition of Cholecystectomy Surgeries")

    # input image and the output label shared by all tabs
    with gr.Row():
        image_input = gr.Image(shape=(255, 255), type="pil")
        label_output = gr.Label()

    with gr.Tab("CNN") as cnn_tab:
        cnn_button = gr.Button("Predict")
        cnn_button.click(cnn, inputs=[image_input], outputs=[label_output])

    with gr.Tab("CNN+Mask") as mask_tab:
        phase = gr.Dropdown([f"{v} - {k}" for k, v in classes.items()], label="Last frame is of phase")
        mask_button = gr.Button("Predict")
        mask_button.click(cnn_mask, inputs=[image_input, phase], outputs=[label_output])

    with gr.Tab("CNN+PE") as pe_tab:
        with gr.Row():
            p0 = gr.Number(label="Phase 0")
            p1 = gr.Number(label="Phase 1")
            p2 = gr.Number(label="Phase 2")
            p3 = gr.Number(label="Phase 3")
            p4 = gr.Number(label="Phase 4")
            p5 = gr.Number(label="Phase 5")
            p6 = gr.Number(label="Phase 6")
        pe_button = gr.Button("Predict")
        pe_button.click(cnn_pe, inputs=[image_input, p0, p1, p2, p3, p4, p5, p6], outputs=[label_output])

    gr.Examples(
        examples=[['images/preparation.png'], ['images/calot-triangle-dissection.png'],
                  ['images/clipping-cutting.png'], ['images/gallbladder-dissection.png'],
                  ['images/gallbladder-packaging.png'], ['images/cleaning-coagulation.png'],
                  ['images/gallbladder-retraction.png']],
        inputs=image_input
    )


if __name__ == "__main__":
    demo.launch(share=True)