import streamlit as st
import torch
import torch.nn as nn
from torchvision import models, transforms
from PIL import Image

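# Load the model once and cache it across reruns: Streamlit re-executes this
# script on every user interaction, so without caching the weights would be
# reloaded from disk each time
@st.cache_resource
def load_model():
    model = models.densenet121()

    # Replace the default 1000-class classifier head with a 5-class linear layer to match the checkpoint
    num_features = model.classifier.in_features
    model.classifier = nn.Linear(num_features, 5)
    model.load_state_dict(torch.load('derma_diseases_detection_best.pt', map_location=torch.device('cpu')))
    model.eval()  # Inference mode: BatchNorm layers use their running statistics
    return model

loaded_model = load_model()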
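
# Optional sanity check (a sketch, not part of the original app): the 5 passed to
# nn.Linear(num_features, 5) has to equal the number of rows of the checkpoint's
# "classifier.weight" tensor, otherwise load_state_dict() raises a size-mismatch
# error. The helper name checkpoint_num_classes is hypothetical, and it assumes the
# .pt file holds a plain state_dict saved from a DenseNet whose classifier is a
# single nn.Linear layer (as constructed above).
def checkpoint_num_classes(state_dict):
    """Return the number of output classes stored in a DenseNet state_dict."""
    # nn.Linear stores its weight with shape (out_features, in_features)
    return state_dict["classifier.weight"].shape[0]

# Example (hypothetical usage):
# checkpoint_num_classes(torch.load('derma_diseases_detection_best.pt', map_location='cpu'))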

# Define the image preprocessing function
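def preprocess_image(image):
    # Accept a PIL image (as returned by Image.open) or a NumPy array, and
    # force 3-channel RGB: PNG uploads can be RGBA, grayscale or palette-based,
    # which would not match the model's 3-channel input
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)
    image = image.convert("RGB")

    # Transform the image using the same transformations as during training
    transform = transforms.Compose([
        transforms.Resize([224, 224]),
        transforms.ToTensor(),  # (3, H, W) float tensor scaled to [0, 1]
        # Enable only if the model was trained with this normalization:
        #transforms.Normalize(mean=[0.5523, 0.5288, 0.5106], std=[0.1012, 0.0820, 0.0509])
    ])
    image = transform(image)
    image = image.unsqueeze(0)  # Add batch dimension: (1, 3, 224, 224)
    return image

# Define the prediction function
def predict_skin_disease(image):
    # Preprocess the raw input image (a PIL image, not an already-preprocessed tensor)
    preprocessed_image = preprocess_image(image)

    # Make prediction
    with torch.no_grad():
        output = loaded_model(preprocessed_image)
        _, predicted_class = torch.max(output, 1)

    # Map the predicted class index to the corresponding class label.
    # The order must match the class indices used during training.
    # NOTE: these names are diabetic-retinopathy grades ("No DR" = no diabetic
    # retinopathy), not skin conditions; check them against the classes the
    # checkpoint was actually trained on.
    class_labels = ['No DR', 'Mild', 'Moderate', 'Severe', 'Proliferative']
    return class_labels[predicted_class.item()]

# Streamlit app
st.title("Skin Disease Detection")

uploaded_image = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])

if uploaded_image is not None:
    # Display the uploaded image
    st.image(uploaded_image, caption="Uploaded Image.", use_column_width=True)

    # Open the upload as a PIL image; predict_skin_disease() does the
    # preprocessing itself, so the image must not be preprocessed here too
    image = Image.open(uploaded_image)

    # Make prediction
    prediction = predict_skin_disease(image)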

    # Display the prediction
    st.success(f"Prediction: {prediction}")
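
# Possible extension (a sketch, not part of the original app): show the softmax
# probability of the most likely classes instead of only the argmax label. The
# helper name top_k_predictions and its signature are hypothetical; it expects
# the raw logits of shape (1, num_classes) returned by loaded_model(...) and a
# list of labels in training-index order.
import torch

def top_k_predictions(logits, labels, k=3):
    """Return the k most likely (label, probability) pairs for one image."""
    probs = torch.softmax(logits, dim=1)[0]  # convert logits to probabilities summing to 1
    values, indices = torch.topk(probs, k=min(k, len(labels)))  # sorted, highest first
    return [(labels[i], p) for p, i in zip(values.tolist(), indices.tolist())]

# Example (hypothetical usage inside the app, after computing `output`):
# for label, p in top_k_predictions(output, class_labels):
#     st.write(f"{label}: {p:.1%}")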