SkinGuard / app.py
Roofus's picture
Update app.py
f50e252 verified
### 1. Imports and class names setup ###
import gradio as gr
import os
import torch
from model import create_ViT
from timeit import default_timer as timer
# Display labels for the two classes; list position is the class index
class_names = [
    'Cancer Negative',   # index 0
    'Malignant Cancer',  # index 1
]
### 2. Model and transforms preparation ###
# create_ViT() supplies the network and the transform applied to inputs before inference
ViT, manual_transforms = create_ViT()

# Read the saved weights onto the CPU, copy them into the network,
# then drop the temporary reference so the loaded copy is not kept around
_saved_weights = torch.load(
    f='Final_ViT_Model_50_Epochs.pth',
    map_location=torch.device('cpu'),
)
ViT.load_state_dict(_saved_weights)
del _saved_weights
### 3. Predict function ###
# Minimum malignant probability required before 'Malignant Cancer' is reported;
# weaker malignant predictions are reported as 'Cancer Negative' instead.
MALIGNANT_CONFIDENCE_THRESHOLD = .7


def predict(img):
    """Classify a skin image and time the prediction.

    Args:
        img: Image to classify, in whatever form ``manual_transforms`` accepts.

    Returns:
        A ``(label_to_confidence, seconds)`` tuple: a one-entry dict mapping
        the reported class name to that class's probability, and the elapsed
        time in seconds rounded to 4 decimal places.

    Raises:
        gr.Error: If no image was supplied (``img is None``).
    """
    # Fail with a readable message instead of an opaque error from the transform
    if img is None:
        raise gr.Error('Please upload an image before submitting.')

    start_time = timer()

    batch = manual_transforms(img).unsqueeze(0)  # add a leading batch dimension
    ViT.eval()
    with torch.inference_mode():
        # assumes the model emits a single logit per image -- TODO confirm
        pred_prob = torch.sigmoid(ViT(batch).squeeze())

    # A rounded output of 1 is reported as class_names[0], so the sigmoid
    # output is treated as that class's probability throughout.
    negative_prob = float(pred_prob)
    malignant_prob = 1 - negative_prob

    if int(torch.round(pred_prob)):
        label, confidence = class_names[0], negative_prob
    elif malignant_prob < MALIGNANT_CONFIDENCE_THRESHOLD:
        # Malignant is the likelier class but falls short of the threshold, so
        # the negative label is reported. Pair it with its own probability
        # rather than the malignant one, so label and score agree.
        label, confidence = class_names[0], negative_prob
    else:
        label, confidence = class_names[1], malignant_prob

    pred_labels_and_preds = {label: confidence}
    pred_time = round(timer() - start_time, 4)
    return pred_labels_and_preds, pred_time
### 4. Gradio App ###
# Text shown on the demo page
title = 'SkinGuard'
description = 'An AI model that can predict whether a CLOSE-UP picture of skin is normal or shows a sign of Skin Cancer.\n\nINSTRUCTIONS:\n\n 1. Take a picture of the area on your skin that you want to check\n\n 2. IMPORTANT: Crop the image to ONLY show the skin lesion / target area. \n\n If you do not follow these instructions, your results may not be as accurate as they could be.'
article = 'IMPORTANT NOTE: If you have followed the instructions in the description of this demo and STILL get malignant results constantly, do NOT take it as an official diagnosis. Contact a healthcare professional for more information if you are concerned.'

# Create examples list: one single-element row per entry in the examples directory.
# os.listdir returns names in arbitrary order, so sort them to keep the
# example gallery in the same order on every start.
example_list = [
    [os.path.join('examples', example)]
    for example in sorted(os.listdir('examples'))
]

# Create the gradio demo
demo = gr.Interface(
    fn=predict,  # maps inputs to outputs
    inputs=gr.Image(type='pil'),
    outputs=[
        gr.Label(num_top_classes=1, label='Predictions'),
        gr.Label(label='Prediction Time (s)'),
    ],
    examples=example_list,
    title=title,
    description=description,
    article=article,
)

# Launch the model!!
demo.launch(debug=False,
            share=True)