Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import torch | |
| from torchvision import transforms | |
| # parameters | |
| from models import Cholec80Model | |
# Run inference on the GPU when torch can see one, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Maps each surgical phase name to the index of its score in the model output.
# Insertion order follows the index order (0..6).
classes = {
    phase_name: phase_index
    for phase_index, phase_name in enumerate(
        (
            "Preparation",
            "Calot Triangle Dissection",
            "Clipping Cutting",
            "Gallbladder Dissection",
            "Gallbladder Packaging",
            "Cleaning Coagulation",
            "Gallbladder Retraction",
        )
    )
}

# Image preprocessing: convert to a tensor, then normalise each channel.
# NOTE(review): the statistics are presumably those of the training frames;
# confirm against the training pipeline.
mean = [0.3456, 0.2281, 0.2233]
std = [0.2528, 0.2135, 0.2104]
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ]
)
| # model imports | |
def load_pretrained_params(model, model_state_path: str, target_device=None):
    """Load matching weights from a checkpoint into ``model`` and ready it for inference.

    The checkpoint must be a mapping with a ``"state_dict"`` entry. When the
    first key of that state dict starts with ``"model."``, the prefix is
    stripped from every key that carries it; keys without the prefix are then
    ignored rather than sliced blindly. Only entries whose (possibly stripped)
    name exists in ``model.state_dict()`` are copied; all other model
    parameters keep their current values.

    After loading, the model is switched to eval mode and moved to
    ``target_device``.

    Args:
        model: The ``torch.nn.Module`` that receives the weights (modified in
            place).
        model_state_path: Path of the checkpoint file.
        target_device: Device to move the model to. Defaults to the
            module-level ``device`` when omitted.

    Raises:
        KeyError: If the checkpoint has no ``"state_dict"`` entry.
    """
    import warnings

    prefix = "model."

    # SECURITY: torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(model_state_path, map_location="cpu")
    pretrained_dict = checkpoint["state_dict"]
    model_dict = model.state_dict()

    # 1. normalise key names; an empty state dict must not raise IndexError
    first_key = next(iter(pretrained_dict), "")
    if first_key.startswith(prefix):
        renamed = {
            key[len(prefix):]: value
            for key, value in pretrained_dict.items()
            if key.startswith(prefix)
        }
    else:
        renamed = pretrained_dict

    # 2. keep only the entries the model actually has
    matched = {key: value for key, value in renamed.items() if key in model_dict}
    if not matched:
        # Without this, the model would silently run with its initial weights.
        warnings.warn(
            f"No parameters in {model_state_path!r} matched the model; "
            "its weights were left unchanged."
        )

    # 3. overwrite the matching entries and load the merged state dict
    model_dict.update(matched)
    model.load_state_dict(model_dict)

    # 4. inference mode, on the requested device
    model.eval()
    model.to(device if target_device is None else target_device)
# Image-only model used by the "CNN" and "CNN+Mask" tabs.
# NOTE(review): the config lists are presumably layer widths; confirm in models.py.
cnn_model = Cholec80Model({"image": [2048]})
load_pretrained_params(model=cnn_model, model_state_path="checkpoints/cnn.ckpt")

# Model that additionally receives a "pos_enc" input, used by the "CNN+PE" tab.
pe_model = Cholec80Model({"image": [2048, 128], "pos_enc": [7, 7, 128]})
load_pretrained_params(model=pe_model, model_state_path="checkpoints/cnn_pe_2.ckpt")
def cnn(label, image):
    """Classify one frame with the image-only model.

    Args:
        label: Ground-truth text supplied by the UI; not used in the prediction.
        image: Frame to classify, in a form accepted by ``transform``.

    Returns:
        dict mapping every phase name in ``classes`` to its softmax
        probability as a Python float.
    """
    # Preprocess and add the leading batch dimension the model is called with.
    frame_batch = transform(image).unsqueeze(dim=0).to(device)

    # The model returns a pair; only the second element (the logits) is used.
    with torch.no_grad():
        _, logits = cnn_model(frame_batch, {})

    probabilities = torch.softmax(logits, dim=1).cpu().numpy()[0]

    scores = {}
    for phase_name, phase_index in classes.items():
        scores[phase_name] = float(probabilities[phase_index])
    return scores
# Additive logit masks, one row per phase of the previous frame. A 0 leaves
# the corresponding logit untouched; -999 pushes its softmax probability to
# effectively zero, so only the phases marked 0 can be predicted.
_PHASE_TRANSITION_MASKS = (
    (0, 0, -999, -999, -999, -999, -999),
    (-999, 0, 0, -999, -999, -999, -999),
    (-999, -999, 0, 0, -999, -999, -999),
    (-999, -999, -999, 0, 0, 0, -999),
    (-999, -999, -999, -999, 0, 0, 0),
    (-999, -999, -999, -999, 0, 0, 0),
    (-999, -999, -999, -999, -999, 0, 0),
)


def cnn_mask(label, image, last_phase):
    """Classify one frame with the image-only model, masked by the previous phase.

    Args:
        label: Ground-truth text supplied by the UI; not used in the prediction.
        image: Frame to classify, in a form accepted by ``transform``.
        last_phase: Dropdown choice formatted as ``"<index> - <name>"``; the
            integer before the first ``-`` selects the mask row.

    Returns:
        dict mapping every phase name in ``classes`` to its softmax
        probability (after masking) as a Python float.

    Raises:
        gr.Error: If no phase was selected, or the selection does not start
            with a valid phase index.
    """
    # The dropdown yields None until the user picks an entry.
    if last_phase is None or not str(last_phase).strip():
        raise gr.Error("Please select the phase of the last frame before predicting.")

    try:
        last_index = int(str(last_phase).split("-")[0].strip())
        mask_row = _PHASE_TRANSITION_MASKS[last_index]
    except (ValueError, IndexError) as err:
        raise gr.Error(f"Unrecognised phase selection: {last_phase!r}") from err

    # Preprocess and add the leading batch dimension the model is called with.
    input_tensor = transform(image).unsqueeze(dim=0).to(device)

    # The model returns a pair; only the second element (the logits) is used.
    with torch.no_grad():
        _, output_tensor = cnn_model(input_tensor, {})

    # Build the mask with the logits' dtype and device so the addition needs
    # no implicit promotion or transfer.
    mask_tensor = torch.tensor(
        [mask_row], dtype=output_tensor.dtype, device=output_tensor.device
    )
    pred_softmax = torch.softmax(output_tensor + mask_tensor, dim=1).cpu().numpy()[0]

    return {k: float(pred_softmax[v]) for k, v in classes.items()}
def cnn_pe(label, image, p_0, p_1, p_2, p_3, p_4, p_5, p_6):
    """Classify one frame with the model that also takes a ``pos_enc`` input.

    Args:
        label: Ground-truth text supplied by the UI; not used in the prediction.
        image: Frame to classify, in a form accepted by ``transform``.
        p_0, p_1, p_2, p_3, p_4, p_5, p_6: Seven numbers packed, in this
            order, into a single-row tensor and passed to the model under the
            ``"pos_enc"`` key. NOTE(review): presumably one value per phase;
            confirm the expected range against the model.

    Returns:
        dict mapping every phase name in ``classes`` to its softmax
        probability as a Python float.
    """
    # One row holding the seven values, in argument order.
    phase_values = [p_0, p_1, p_2, p_3, p_4, p_5, p_6]
    pos_enc = torch.Tensor([phase_values]).to(device)

    # Preprocess and add the leading batch dimension the model is called with.
    frame_batch = transform(image).unsqueeze(dim=0).to(device)

    # The model returns a pair; only the second element (the logits) is used.
    with torch.no_grad():
        _, logits = pe_model(frame_batch, {"pos_enc": pos_enc})

    probabilities = torch.softmax(logits, dim=1).cpu().numpy()[0]

    scores = {}
    for phase_name, phase_index in classes.items():
        scores[phase_name] = float(probabilities[phase_index])
    return scores
# Gradio UI: one shared image / ground-truth input, one shared result label,
# and a tab per prediction variant.
# NOTE(review): the nesting of the `with` blocks is reconstructed, because the
# original indentation was not available; confirm the layout.
# NOTE(review): `gr.Image(shape=...)` may not be accepted by newer Gradio
# releases; confirm against the pinned Gradio version.
with gr.Blocks() as demo:
    gr.Markdown("# Phase Recognition of Cholecystectomy Surgeries")

    with gr.Row():
        # inputs
        with gr.Column():
            groundtruth_lable = gr.Text(label="Ground Truth", interactive=False)
            image_input = gr.Image(shape=(255, 255), type="pil")
        # output
        lable_output = gr.Label(label="Result")

    with gr.Tab("CNN") as cnn_tab:
        cnn_button = gr.Button("Predict")
        cnn_button.click(
            cnn,
            inputs=[groundtruth_lable, image_input],
            outputs=[lable_output],
        )

    with gr.Tab("CNN+Mask") as mask_tab:
        # Choices read "<index> - <name>"; cnn_mask parses the leading index.
        phase = gr.Dropdown(
            [f"{v} - {k}" for k, v in classes.items()],
            label="Last frame is of phase",
        )
        mask_button = gr.Button("Predict")
        mask_button.click(
            cnn_mask,
            inputs=[groundtruth_lable, image_input, phase],
            outputs=[lable_output],
        )

    with gr.Tab("CNN+PE") as pe_tab:
        with gr.Row():
            # Seven number boxes labelled "Phase 0" .. "Phase 6", created in order.
            p0, p1, p2, p3, p4, p5, p6 = (
                gr.Number(label=f"Phase {index}") for index in range(7)
            )
        pe_button = gr.Button("Predict")
        pe_button.click(
            cnn_pe,
            inputs=[groundtruth_lable, image_input, p0, p1, p2, p3, p4, p5, p6],
            outputs=[lable_output],
        )

    # One example per phase: [ground-truth text, image path]. The text is
    # built from `classes` in index order, in the same format as the dropdown.
    example_images = (
        "images/pe-sample.png",
        "images/mask-sample.png",
        "images/clipping-cutting.png",
        "images/gallbladder-dissection.png",
        "images/gallbladder-packaging.png",
        "images/cleaning-coagulation.png",
        "images/gallbladder-retraction.png",
    )
    gr.Examples(
        examples=[
            [f"{phase_index} - {phase_name}", image_path]
            for (phase_name, phase_index), image_path in zip(
                classes.items(), example_images
            )
        ],
        inputs=[groundtruth_lable, image_input],
    )

if __name__ == "__main__":
    demo.launch()