Spaces:
Runtime error
Runtime error
# import pytorch related dependencies | |
import torch | |
from torch import nn | |
import numpy as np | |
import torchvision as torchvision | |
import torchvision.transforms as transforms | |
from pytorch_grad_cam import GradCAM | |
from pytorch_grad_cam.utils.image import show_cam_on_image | |
import gradio as gr | |
# model setup
# NOTE(review): `device` is computed but not used in this file; inference and
# Grad-CAM below run on CPU tensors. Kept in case other code imports it.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
classes = ['actinic keratoses', 'basal cell carcinoma', 'benign keratosis-like lesions',
           'dermatofibroma', 'melanoma', 'melanocytic nevi', 'vascular lesions']
model = torchvision.models.mobilenet_v3_large(pretrained=False)
# Replace the final classifier layer with one output per lesion class. The
# sizes are derived (backbone width, number of classes) instead of being
# hard-coded, so the head cannot drift out of sync with `classes`.
model.classifier[3] = nn.Linear(model.classifier[3].in_features, len(classes))
# The checkpoint is a dict whose "model_state_dict" entry holds the weights;
# map it to CPU so it loads regardless of the device it was saved from.
state_dict_trained = torch.load(
    'checkpoints/ham10k_checkpoint_mobile_0.82_epoch24.pt',
    map_location=torch.device('cpu'),
)
model.load_state_dict(state_dict_trained["model_state_dict"])
# Inference mode (affects dropout / batch-norm behaviour).
model.eval()
# image pre-processing
# Per-channel normalization statistics. They must match what the checkpoint
# was trained with. NOTE(review): these numbers coincide with the commonly
# quoted CIFAR-10 statistics; confirm against the training script.
norm_mean = (0.4914, 0.4822, 0.4465)
norm_std = (0.2023, 0.1994, 0.2010)
_preprocess_steps = [
    # keep only the central 400x400 region
    transforms.CenterCrop((400, 400)),
    # image -> float CHW tensor
    transforms.ToTensor(),
    # (x - mean) / std per channel; undone by tensor2npimg
    transforms.Normalize(norm_mean, norm_std),
]
transform = transforms.Compose(_preprocess_steps)
# convert a normalized tensor back to a numpy image
def tensor2npimg(tensor, mean, std):
    """Undo per-channel normalization and return the image as an HWC numpy array.

    Args:
        tensor: normalized image tensor laid out as (C, H, W). It is not
            modified. It may require grad or live on any device.
        mean: per-channel means that were subtracted (one value per channel).
        std: per-channel standard deviations that were divided by.

    Returns:
        numpy array of shape (H, W, C) holding ``tensor * std + mean``.
    """
    # detach() so tensors that require grad can be converted; clone() so the
    # in-place ops below never touch the caller's tensor.
    tensor = tensor.detach().clone()
    mean_tensor = torch.as_tensor(list(mean), dtype=tensor.dtype, device=tensor.device).view(-1, 1, 1)
    std_tensor = torch.as_tensor(list(std), dtype=tensor.dtype, device=tensor.device).view(-1, 1, 1)
    # inverse of Normalize: x * std + mean, broadcast over H and W
    tensor.mul_(std_tensor).add_(mean_tensor)
    # .cpu() so that CUDA tensors can be converted too
    npimg = tensor.cpu().numpy()
    return np.transpose(npimg, (1, 2, 0))  # C*H*W => H*W*C
# draw Grad-CAM on image
def image_grad_cam(model, input_tensor, input_float_np, target_layers):
    """Overlay a Grad-CAM heatmap on an image.

    No explicit targets are passed, so pytorch-grad-cam chooses the class to
    explain itself (per its documentation, the highest-scoring category).

    Args:
        model: the classifier to explain.
        input_tensor: batched, normalized input fed to ``model``; only the
            heatmap of the first batch element is used.
        input_float_np: the same image as an HWC float array, expected in
            [0, 1]; it is clipped to that range before overlaying.
        target_layers: list of modules whose activations/gradients Grad-CAM
            uses, usually the last convolutional block. A single module is
            also accepted and wrapped in a list. Common choices:
                FasterRCNN: model.backbone
                Resnet18 and 50: model.layer4[-1]
                VGG and densenet161: model.features[-1]
                mnasnet1_0: model.layers[-1]
                ViT: model.blocks[-1].norm1
                SwinT: model.layers[-1].blocks[-1].norm1

    Returns:
        The RGB overlay image produced by ``show_cam_on_image``.
    """
    # GradCAM iterates over target_layers. A bare nn.Sequential passed here
    # would silently hook each of its children instead of the block's output.
    if isinstance(target_layers, nn.Module) and not isinstance(target_layers, nn.ModuleList):
        target_layers = [target_layers]
    # Undoing Normalize can overshoot 1.0 by float rounding, and
    # show_cam_on_image raises on values > 1, so clip defensively.
    rgb_img = np.clip(input_float_np, 0.0, 1.0).astype(np.float32)
    # The context manager releases the hooks GradCAM registers on the model.
    # `use_cuda` is omitted: it defaulted to False in older pytorch-grad-cam
    # releases and no longer exists in newer ones (device follows the model).
    with GradCAM(model=model, target_layers=target_layers) as cam:
        grayscale_cam = cam(input_tensor=input_tensor, aug_smooth=True, eigen_smooth=True)
    return show_cam_on_image(rgb_img, grayscale_cam[0], use_rgb=True)
# config the predict function for Gradio; input_img arrives as a PIL image (type="pil")
def predict(gt, input_img):
    """Classify a skin-lesion image and build the three Gradio outputs.

    Args:
        gt: ground-truth label typed by the user. It is only displayed in the
            UI and is not used for the prediction.
        input_img: the uploaded image, as delivered by the ``gr.Image`` input.

    Returns:
        Tuple ``(probabilities, suggestion, cam_image)``:
        ``probabilities`` maps every class name to its softmax probability,
        ``suggestion`` is a short recommendation string, and ``cam_image`` is
        the Grad-CAM overlay image.
    """
    lesion_tensor = transform(input_img)
    input_float_np = tensor2npimg(lesion_tensor, norm_mean, norm_std)
    lesion_batch = lesion_tensor.unsqueeze(dim=0)
    with torch.no_grad():
        logits = model(lesion_batch)
    # Softmax is applied directly to the raw outputs. Exponentiating first
    # would distort the probabilities (softmax(exp(z)) != softmax(z)).
    # Softmax is also correct if the network emitted log-probabilities,
    # because it is invariant to the per-row log-sum-exp shift.
    probs = torch.softmax(logits, dim=1)[0].cpu().numpy()
    # index of the class with the highest probability, as a plain int
    pred = int(torch.argmax(logits, dim=1)[0].item())
    # diagnostic suggestions: escalate the classes flagged as urgent
    # (basal cell carcinoma and melanoma)
    urgent = {'basal cell carcinoma', 'melanoma'}
    if classes[pred] in urgent:
        suggestion = "CHECK WITH YOUR MD!"
    else:
        suggestion = "Nothing to be worried about."
    # Grad-CAM expects a list of target layers; passing the Sequential block
    # itself would hook each of its children instead.
    target_layers = [model.features[-1]]
    output_img = image_grad_cam(model, lesion_batch, input_float_np, target_layers)
    label_probs = {name: float(p) for name, p in zip(classes, probs)}
    return label_probs, suggestion, output_img
# Gradio < 4 accepts `shape`, which resizes and crops uploads to 400x400.
# Gradio 4 removed that argument and raises TypeError if it is given, so it
# is passed only where it is supported. On Gradio 4 the transform's
# CenterCrop alone sizes the image.
_image_input_kwargs = {"type": "pil", "label": "Image"}
if int(gr.__version__.split(".")[0]) < 4:
    _image_input_kwargs["shape"] = (400, 400)

demo = gr.Interface(
    fn=predict,
    inputs=[gr.Text(label="Ground Truth"), gr.Image(**_image_input_kwargs)],
    outputs=[
        gr.Label(label="Predict Result"),
        gr.Text(label="Recommendation", interactive=False),
        gr.Image(label="GradCAM"),
    ],
    examples=[['actinic keratoses', 'images/akiec.jpg'],
              ['basal cell carcinoma', 'images/bcc.jpg'],
              ['benign keratosis-like lesions', 'images/bkl.jpg'],
              ['dermatofibroma', 'images/df.jpg'],
              ['melanoma', 'images/mel.jpg'],
              ['melanoma', 'images/mel2.jpg'],
              ['melanoma', 'images/mel3.jpg'],
              ['melanocytic nevi', 'images/nv.jpg'],
              ['melanocytic nevi', 'images/nv2.jpg']],
    title="Skin Lesion Classifier",
)
demo.launch()