SDSdemo / runSDSdemo.py
Tobias Czempiel
workflow stuff
74fe289
# import pytorch related dependencies
import torch
from PIL import Image
from torch import nn
import numpy as np
import torchvision as torchvision
import torchvision.transforms as transforms
from pytorch_grad_cam import GradCAM, ScoreCAM, GradCAMPlusPlus, AblationCAM, XGradCAM, EigenCAM, FullGrad
from pytorch_grad_cam.utils.image import show_cam_on_image
import gradio as gr
import timm
# model setup
import os

# `device` is only used as the map_location for the checkpoint below; this
# script never calls .to(device) on the model, so inference runs wherever the
# freshly created model lives.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Output labels: entry i of the classifier output is reported as classes[i].
classes = [
    "Preparation",
    "CalotTriangleDissection",
    "ClippingCutting",
    "GallbladderDissection",
    "GallbladderPackaging",
    "CleaningCoagulation",
    "GallbladderRetraction",
]

# EfficientNet-B3 from timm. Its stock classifier head is swapped for a linear
# layer with 7 outputs, one per entry in `classes`.
model = timm.create_model('efficientnet_b3')
model.classifier = nn.Linear(1536, 7)

# The checkpoint path below is relative, so log the working directory to make
# a "file not found" easy to diagnose.
print(os.getcwd())

# NOTE: torch.load unpickles the file; only point it at trusted checkpoints.
state_dict_trained = torch.load(
    'checkpoints/state_dict_timm_effnet_b3_e6_val_f1=0.75.pt',
    map_location=device,
)
print(state_dict_trained.keys())

# The checkpoint stores the backbone under 'model.model.<name>' and the head
# under 'model.fc_phase.*', so remap its entries onto this model's own keys.
sd = model.state_dict()
for k in sd:
    if "classifier" not in k:
        sd[k] = state_dict_trained[f'model.model.{k}']
sd['classifier.weight'] = state_dict_trained['model.fc_phase.weight']
sd['classifier.bias'] = state_dict_trained['model.fc_phase.bias']

model.load_state_dict(sd)  # load the trained weights into our model
model.eval()  # inference mode (no dropout, batch-norm uses running statistics)
# image pre-processing
# Per-channel mean / std used to normalise the model input. The same values
# are passed to tensor2npimg to undo the normalisation.
# NOTE(review): these look like the widely used CIFAR-10 statistics - confirm
# they match what the checkpoint was trained with.
norm_mean = (0.4914, 0.4822, 0.4465)
norm_std = (0.2023, 0.1994, 0.2010)

# Pipeline applied to every PIL image before it reaches the model:
#   1. take the central 400x400 crop (this crops, it does not rescale),
#   2. convert the image to a tensor,
#   3. normalise each channel with the statistics above.
transform = transforms.Compose(
    [
        transforms.CenterCrop(size=(400, 400)),
        transforms.ToTensor(),
        transforms.Normalize(mean=norm_mean, std=norm_std),
    ]
)
# convert tensor to numpy array
def tensor2npimg(tensor, mean, std):
    """Undo per-channel normalisation and return the image as a NumPy array.

    Args:
        tensor: Normalised image tensor laid out channel-first (C x H x W).
            It is not modified; the work is done on a detached copy.
        mean: Per-channel means that were subtracted during normalisation.
        std: Per-channel standard deviations that were divided by.

    Returns:
        numpy.ndarray laid out H x W x C holding ``tensor * std + mean``.
    """
    # Work on a detached copy: the in-place maths below must not touch the
    # caller's tensor, and .numpy() refuses tensors that still require grad.
    image = tensor.detach().clone()

    # Shape the statistics as (C, 1, 1) so they broadcast over H and W.
    mean_tensor = torch.as_tensor(list(mean), dtype=image.dtype, device=image.device).view(-1, 1, 1)
    std_tensor = torch.as_tensor(list(std), dtype=image.dtype, device=image.device).view(-1, 1, 1)

    # inverse of normalization: x * std + mean
    image.mul_(std_tensor).add_(mean_tensor)

    # .cpu() is a no-op for CPU tensors and makes CUDA tensors convertible.
    npimg = image.cpu().numpy()
    return np.transpose(npimg, (1, 2, 0))  # C*H*W => H*W*C
# draw Grad-CAM on image
# target layer could be any layer before the final attention block
# Some common choices are:
# FasterRCNN: model.backbone
# Resnet18 and 50: model.layer4[-1]
# VGG and densenet161: model.features[-1]
# mnasnet1_0: model.layers[-1]
# ViT: model.blocks[-1].norm1
# SwinT: model.layers[-1].blocks[-1].norm1
def image_grad_cam(model, input_tensor, input_float_np, target_layers):
    """Overlay a Grad-CAM heat map for ``input_tensor`` on ``input_float_np``.

    Args:
        model: Network handed to ``GradCAM``.
        input_tensor: Model input the heat map is computed for
            (presumably batched; only the first entry is used - TODO confirm).
        input_float_np: Image the heat map is drawn onto, passed unchanged to
            ``show_cam_on_image``.
        target_layers: Layers handed to ``GradCAM`` (see the list above).

    Returns:
        Whatever ``show_cam_on_image`` returns for the first heat map,
        requested in RGB channel order.
    """
    explainer = GradCAM(model=model, target_layers=target_layers, use_cuda=False)
    # Both smoothing options are switched on for the CAM computation.
    heatmaps = explainer(input_tensor=input_tensor, aug_smooth=True, eigen_smooth=True)
    first_heatmap = heatmaps[0, :]  # keep only index 0 along the leading axis
    return show_cam_on_image(input_float_np, first_heatmap, use_rgb=True)
# config the predict function for Gradio, input type of image is numpy.ndarray
def predict(input_img):
    """Classify one frame and return the probability of every surgical phase.

    Args:
        input_img: Image array supplied by the Gradio image input. It is cast
            to uint8 and interpreted as an RGB image.

    Returns:
        dict mapping each name in ``classes`` to its softmax probability
        (a Python float), in the format expected by the Gradio Label output.
    """
    # numpy.ndarray -> PIL.Image
    frame = Image.fromarray(input_img.astype('uint8'), 'RGB')

    # Crop / tensorise / normalise, then add the leading batch dimension.
    batch = transform(frame).unsqueeze(dim=0)

    # predict
    with torch.no_grad():
        logits = model(batch)

    # The classifier head is a plain nn.Linear, so the model returns logits and
    # softmax has to be applied to them directly. Exponentiating first (as the
    # previous version did) distorts the distribution and can overflow to inf,
    # which turns the softmax output into NaN.
    probabilities = torch.softmax(logits, dim=1).cpu().numpy()[0]

    # label -> probability for every class
    return {name: float(prob) for name, prob in zip(classes, probabilities)}
# start gradio application
# Example frames offered in the UI; each inner list holds the value for the
# single image input.
example_frames = [
    ['images/video01_000014_prep.png'],
    ['images/video01_001403.png'],
    ['images/video01_001528_pack.png'],
]

# One image in, one label widget out, served by `predict`.
demo = gr.Interface(
    fn=predict,
    inputs=gr.inputs.Image(),
    outputs=[gr.outputs.Label(label="Predict Result")],
    examples=example_frames,
    title="Surgical Workflow Classifier",
)
demo.launch()