File size: 2,603 Bytes
8f4809f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1f24488
8f4809f
1f24488
 
d79ba99
1f24488
8f4809f
 
 
 
 
 
 
 
a3a588d
 
8f4809f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
### 1. Imports and class names setup ###
import gradio as gr
import os
import torch

from model import create_ViT
from timeit import default_timer as timer

# Setup class names. predict() indexes into this list, so its order matters.
class_names = ['Normal', 'Malignant']

### 2. Model and transforms preparation ###
# create_ViT() hands back a (model, transform) pair.
ViT, manual_transforms = create_ViT()

# Load saved weights, remapping every stored tensor onto the CPU so the
# checkpoint can be restored on a machine without a GPU.
ViT.load_state_dict(torch.load('Final_ViT_Model_50_Epochs.pth', map_location='cpu'))

### 3. Predict function ###
# Minimum malignant probability required before an image is reported as
# 'Malignant'; anything below this is reported as 'Normal'.
MALIGNANT_THRESHOLD = 0.75


def predict(img):
    """Classify a skin image and time the prediction.

    The sigmoid of the model's output is treated as the probability of
    'Normal' (``class_names[0]``) and its complement as the probability of
    'Malignant' (``class_names[1]``). This mapping mirrors how the original
    index arithmetic used the output; presumably the model was trained with
    'Normal' as the positive class -- TODO confirm against training code.

    The image is reported as 'Malignant' only when its malignant probability
    is at least ``MALIGNANT_THRESHOLD``; otherwise it is reported as 'Normal'.
    The returned confidence is always the probability of the class shown.

    Args:
        img: The input image (presumably a PIL image, given the
            ``gr.Image(type='pil')`` input component).

    Returns:
        A tuple ``(pred_labels_and_preds, pred_time)``:
        ``pred_labels_and_preds`` is a single-entry dict mapping the reported
        class name to that class's probability, and ``pred_time`` is the
        elapsed wall-clock time of the call in seconds, rounded to 4 places.
    """
    start_time = timer()

    # Preprocess and add a leading batch dimension for the model.
    img = manual_transforms(img).unsqueeze(0)

    ViT.eval()
    with torch.inference_mode():
        pred_prob = torch.sigmoid(ViT(img).squeeze())

    # Convert once to a Python float; the two class probabilities sum to 1.
    prob_normal = float(pred_prob)
    prob_malignant = 1 - prob_normal

    if prob_malignant >= MALIGNANT_THRESHOLD:
        pred_labels_and_preds = {class_names[1]: prob_malignant}
    else:
        # Report Normal's own probability, even when it is below 0.5 because
        # the malignant score fell short of the threshold. Showing the
        # malignant probability next to the 'Normal' label would be wrong.
        pred_labels_and_preds = {class_names[0]: prob_normal}

    end_time = timer()
    pred_time = round(end_time - start_time, 4)

    return pred_labels_and_preds, pred_time

### 4. Gradio App ###
# Create interface for gradio
title = 'SkinGuard'
description = 'An AI model that can predict whether a CLOSE-UP picture of skin is normal or shows a sign of Skin Cancer.\n\nINSTRUCTIONS:\n\n 1. Take a picture of the area on your skin that you want to check\n\n 2. IMPORTANT: Crop the image to ONLY show the skin lesion / target area. \n\n If you do not follow these instructions, your results may not be as accurate as they could be.'
article = 'This is our demo for the DEDA Entrepreneurship Competition 2024. Created by Imran, Brian, Lukas, and Rohit, all in Baez CSE Period 5. \n\n IMPORTANT NOTE: If you have followed the instructions in the description of this demo and STILL get malignant results constantly, do NOT take it as an official diagnosis. Contact a healthcare professional for more information if you are concerned.'
# Create examples list: one single-input example per file in the examples
# directory. The listing is sorted because os.listdir returns entries in
# arbitrary order. Hidden files (e.g. .DS_Store, .gitkeep) and
# sub-directories are skipped since they are not image examples.
examples_dir = 'examples'
example_list = [
    [os.path.join(examples_dir, name)]
    for name in sorted(os.listdir(examples_dir))
    if not name.startswith('.')
    and os.path.isfile(os.path.join(examples_dir, name))
]

# Create the gradio demo
demo = gr.Interface(fn=predict,  # maps inputs to outputs
                    inputs=gr.Image(type='pil'),
                    outputs=[gr.Label(num_top_classes=1, label='Predictions'),
                             gr.Label(label='Prediction Time (s)')],
                    examples=example_list,
                    title=title,
                    description=description,
                    article=article)
# Launch the app. share=True asks Gradio to create a temporary public link.
demo.launch(debug=False,
            share=True)