import gradio as gr
import torch
from torchvision import transforms

from models import Cholec80Model

# parameters
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Phase name -> class index of the model's output logits.
classes = {"Preparation": 0, "Calot Triangle Dissection": 1, "Clipping Cutting": 2, "Gallbladder Dissection": 3,
           "Gallbladder Packaging": 4, "Cleaning Coagulation": 5, "Gallbladder Retraction": 6}

# image transformations
# Per-channel normalization statistics applied to every input image
# (presumably computed on the training data — TODO confirm).
mean, std = [0.3456, 0.2281, 0.2233], [0.2528, 0.2135, 0.2104]
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=mean, std=std)])

# Additive logit masks used by the "CNN+Mask" mode. Row i is applied when the last frame
# was of phase i: a 0 leaves that phase's logit unchanged, -999 pushes its softmax
# probability to ~0. Defined once at module level instead of on every request.
PHASE_TRANSITION_MASKS = [
    [0, 0, -999, -999, -999, -999, -999],
    [-999, 0, 0, -999, -999, -999, -999],
    [-999, -999, 0, 0, -999, -999, -999],
    [-999, -999, -999, 0, 0, 0, -999],
    [-999, -999, -999, -999, 0, 0, 0],
    [-999, -999, -999, -999, 0, 0, 0],
    [-999, -999, -999, -999, -999, 0, 0],
]


# model imports
def load_pretrained_params(model, model_state_path: str):
    """Copy matching weights from a checkpoint into ``model`` and prepare it for inference.

    Only checkpoint entries whose names exist in ``model.state_dict()`` are copied; every
    other parameter keeps its current value. If the first checkpoint key starts with
    ``"model."``, that 6-character prefix is stripped from every key before matching.
    Afterwards the model is put in eval mode and moved to ``device``.

    :param model: module to populate in place.
    :param model_state_path: path to a checkpoint file holding a ``"state_dict"`` entry.
    """
    # NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
    checkpoint = torch.load(model_state_path, map_location="cpu")
    pretrained_dict = checkpoint["state_dict"]
    model_dict = model.state_dict()

    # 1. filter out unnecessary keys (and strip the "model." prefix when the checkpoint uses it).
    #    next(iter(...), "") avoids an IndexError when the checkpoint's state dict is empty.
    prefix = "model."
    if next(iter(pretrained_dict), "").startswith(prefix):
        pretrained_dict = {k[len(prefix):]: v for k, v in pretrained_dict.items()
                           if k[len(prefix):] in model_dict}
    else:
        pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
    # 2. overwrite entries in the existing state dict
    model_dict.update(pretrained_dict)
    # 3. load the new state dict
    model.load_state_dict(model_dict)
    # 4. eval mode
    model.eval()
    # 5. put model to device
    model.to(device)


cnn_model = Cholec80Model({"image": [2048]})
load_pretrained_params(cnn_model, "checkpoints/cnn.ckpt")
pe_model = Cholec80Model({"image": [2048, 128], "pos_enc": [7, 7, 128]})
load_pretrained_params(pe_model, "checkpoints/cnn_pe_2.ckpt")


def _preprocess(image):
    """Normalize ``image`` (a PIL image from the Gradio input) into a batch of one on ``device``.

    :raises ValueError: if no image was provided.
    """
    if image is None:
        raise ValueError("Please provide an image first.")
    return transform(image).unsqueeze(dim=0).to(device)


def _to_label_dict(logits):
    """Softmax ``logits`` over dim 1 and map each phase name to the probability of the first row."""
    probabilities = torch.softmax(logits, dim=1).cpu().numpy()[0]
    return {name: float(probabilities[idx]) for name, idx in classes.items()}


def _parse_phase(last_phase):
    """Extract the integer phase id from a dropdown entry formatted as ``"<id> - <name>"``.

    :raises ValueError: if nothing is selected or the id is outside the mask table.
    """
    if not last_phase:
        raise ValueError("Please select the phase of the last frame.")
    phase_id = int(last_phase.split("-")[0].strip())
    if not 0 <= phase_id < len(PHASE_TRANSITION_MASKS):
        raise ValueError(f"Unknown phase id: {phase_id}")
    return phase_id


def cnn(label, image):
    """Classify ``image`` with the plain CNN model.

    :param label: ground-truth text shown in the UI (unused by the prediction).
    :param image: input frame as a PIL image.
    :returns: mapping phase name -> predicted probability.
    """
    input_tensor = _preprocess(image)
    with torch.no_grad():
        # the model returns a pair; the second element holds the class logits
        _, output_tensor = cnn_model(input_tensor, {})
    return _to_label_dict(output_tensor)


def cnn_mask(label, image, last_phase):
    """Classify ``image`` with the CNN, restricting predictions to phases allowed after ``last_phase``.

    :param label: ground-truth text shown in the UI (unused by the prediction).
    :param image: input frame as a PIL image.
    :param last_phase: dropdown entry ``"<id> - <name>"`` naming the previous frame's phase.
    :returns: mapping phase name -> predicted probability after masking.
    """
    phase_id = _parse_phase(last_phase)
    mask_tensor = torch.tensor([PHASE_TRANSITION_MASKS[phase_id]]).to(device)
    input_tensor = _preprocess(image)
    with torch.no_grad():
        _, output_tensor = cnn_model(input_tensor, {})
    # the additive mask is applied to the logits before the softmax
    return _to_label_dict(output_tensor + mask_tensor)


def cnn_pe(label, image, p_0, p_1, p_2, p_3, p_4, p_5, p_6):
    """Classify ``image`` with the CNN+position-encoding model.

    :param label: ground-truth text shown in the UI (unused by the prediction).
    :param image: input frame as a PIL image.
    :param p_0..p_6: the seven position-encoding values, one per phase, as entered in the UI.
    :returns: mapping phase name -> predicted probability.
    :raises ValueError: if any position-encoding value is missing.
    """
    values = [p_0, p_1, p_2, p_3, p_4, p_5, p_6]
    if any(value is None for value in values):
        raise ValueError("Please fill in all seven phase values for the position encoding.")
    # form the position encoder vector (batch of one)
    pos_enc = torch.Tensor([values]).to(device)
    input_tensor = _preprocess(image)
    with torch.no_grad():
        _, output_tensor = pe_model(input_tensor, {"pos_enc": pos_enc})
    return _to_label_dict(output_tensor)


with gr.Blocks() as demo:
    gr.Markdown("# Phase Recognition of Cholecystectomy Surgeries")
    # inputs
    with gr.Row():
        with gr.Column():
            groundtruth_lable = gr.Text(label="Ground Truth", interactive=False)
            image_input = gr.Image(shape=(255, 255), type="pil")
        # output
        lable_output = gr.Label(label="Result")
    with gr.Tab("CNN") as cnn_tab:
        cnn_button = gr.Button("Predict")
        cnn_button.click(cnn, inputs=[groundtruth_lable, image_input], outputs=[lable_output])
    with gr.Tab("CNN+Mask") as mask_tab:
        phase = gr.Dropdown([f"{v} - {k}" for k, v in classes.items()], label="Last frame is of phase")
        mask_button = gr.Button("Predict")
        mask_button.click(cnn_mask, inputs=[groundtruth_lable, image_input, phase], outputs=[lable_output])
    with gr.Tab("CNN+PE") as pe_tab:
        with gr.Row():
            p0 = gr.Number(label="Phase 0")
            p1 = gr.Number(label="Phase 1")
            p2 = gr.Number(label="Phase 2")
            p3 = gr.Number(label="Phase 3")
            p4 = gr.Number(label="Phase 4")
            p5 = gr.Number(label="Phase 5")
            p6 = gr.Number(label="Phase 6")
        pe_button = gr.Button("Predict")
        pe_button.click(cnn_pe, inputs=[groundtruth_lable, image_input, p0, p1, p2, p3, p4, p5, p6],
                        outputs=[lable_output])
    gr.Examples(
        examples=[['0 - Preparation', 'images/pe-sample.png'],
                  ['1 - Calot Triangle Dissection', 'images/mask-sample.png'],
                  ['2 - Clipping Cutting', 'images/clipping-cutting.png'],
                  ['3 - Gallbladder Dissection', 'images/gallbladder-dissection.png'],
                  ['4 - Gallbladder Packaging', 'images/gallbladder-packaging.png'],
                  ['5 - Cleaning Coagulation', 'images/cleaning-coagulation.png'],
                  ['6 - Gallbladder Retraction', 'images/gallbladder-retraction.png']],
        inputs=[groundtruth_lable, image_input]
    )

if __name__ == "__main__":
    demo.launch()