### 1. Imports and class names setup ###
import os
from timeit import default_timer as timer

import gradio as gr
import torch

from model import create_ViT

# Setup class names
class_names = ['Cancer Negative', 'Malignant Cancer']

# The app reports 'Malignant Cancer' only when (1 - sigmoid score) reaches this
# value. Below it, the app reports 'Cancer Negative'.
MALIGNANT_CONFIDENCE_THRESHOLD = 0.7

### 2. Model and transforms preparation ###
ViT, manual_transforms = create_ViT()

# Load saved weights
ViT.load_state_dict(
    torch.load(
        f='Final_ViT_Model_50_Epochs.pth',
        map_location=torch.device('cpu'),  # Load model to cpu
    )
)
# This script only ever runs the model for inference. Switch it to eval mode
# once here instead of on every request.
ViT.eval()


### 3. Predict function ###
def _to_label_confidences(prob):
    """Map a sigmoid score to the single {label: confidence} pair shown in the UI.

    Args:
        prob: sigmoid output of the model, a float in [0, 1].

    Returns:
        A one-entry dict:
          * round(prob) == 1 (prob > 0.5)       -> {'Cancer Negative': prob}
          * else, 1 - prob < threshold          -> {'Cancer Negative': 1 - prob}
          * else                                -> {'Malignant Cancer': 1 - prob}
        ``round`` rounds half to even, so prob == 0.5 falls into the "else" arms.

    Raises:
        ValueError: if prob is NaN (``round`` cannot convert NaN to int).
    """
    # NOTE(review): this treats prob as the score of class_names[0]
    # ('Cancer Negative'). Confirm that against the training label encoding.
    # Also note the middle arm attaches 1 - prob to 'Cancer Negative', while the
    # first arm attaches prob to the same label. Both cannot be that class's
    # probability, so review before relying on the displayed confidence.
    if round(prob):
        return {class_names[0]: prob}
    complement = 1 - prob
    if complement < MALIGNANT_CONFIDENCE_THRESHOLD:
        return {class_names[0]: complement}
    return {class_names[1]: complement}


def predict(img):
    """Classify one skin image and report how long the prediction took.

    Args:
        img: the image supplied by the Gradio image input. It is passed through
            ``manual_transforms`` and given a leading batch dimension.

    Returns:
        A tuple ``(label_confidences, pred_time)``. ``label_confidences`` is a
        one-entry dict mapping a class name to a confidence, as built by
        ``_to_label_confidences``. ``pred_time`` is the elapsed wall time in
        seconds, rounded to 4 decimals.
    """
    start_time = timer()
    img = manual_transforms(img).unsqueeze(0)  # add batch dimension
    with torch.inference_mode():
        # Assumes the model emits a single logit per image. float() raises if
        # the squeezed output has more than one element.
        pred_prob = float(torch.sigmoid(ViT(img).squeeze()))
    pred_labels_and_preds = _to_label_confidences(pred_prob)
    pred_time = round(timer() - start_time, 4)
    return pred_labels_and_preds, pred_time


### 4. Gradio App ###
# Create interface for gradio
title = 'SkinGuard'
description = (
    'An AI model that can predict whether a CLOSE-UP picture of skin is normal '
    'or shows a sign of Skin Cancer.\n\n'
    'INSTRUCTIONS:\n\n'
    ' 1. Take a picture of the area on your skin that you want to check\n\n'
    ' 2. IMPORTANT: Crop the image to ONLY show the skin lesion / target area. \n\n'
    ' If you do not follow these instructions, your results may not be as '
    'accurate as they could be.'
)
article = (
    'IMPORTANT NOTE: If you have followed the instructions in the description '
    'of this demo and STILL get malignant results constantly, do NOT take it '
    'as an official diagnosis. Contact a healthcare professional for more '
    'information if you are concerned.'
)
# Create examples list: one single-input example per file in examples/.
# Sorted so the gallery order is stable (os.listdir order is arbitrary).
# Hidden files such as .DS_Store or .gitkeep are skipped so they are not
# offered as images.
example_list = [
    [os.path.join('examples', example)]
    for example in sorted(os.listdir('examples'))
    if not example.startswith('.')
]

# Create the gradio demo
demo = gr.Interface(
    fn=predict,  # maps inputs to outputs
    inputs=gr.Image(type='pil'),
    outputs=[
        gr.Label(num_top_classes=1, label='Predictions'),
        gr.Label(label='Prediction Time (s)'),
    ],
    examples=example_list,
    title=title,
    description=description,
    article=article,
)

# Launch the model!!
demo.launch(debug=False, share=True)