# class-model / app.py
# Author: Oualidra — commit 6f6b8aa ("Create app.py")
import numpy as np
import streamlit as st
import torch
import torch.nn as nn
from PIL import Image
from torchvision import models, transforms
# Load the model: a DenseNet-121 whose classifier head is replaced by a
# 5-output linear layer so its shape matches the saved checkpoint, with the
# weights mapped onto the CPU.
# NOTE(review): Streamlit re-executes this whole script on every widget
# interaction, so the network is rebuilt and the checkpoint re-read each
# time; consider caching it (e.g. ``st.cache_resource`` if the installed
# Streamlit version provides it).
# NOTE(review): ``torch.load`` unpickles the file; on torch >= 1.13 consider
# ``weights_only=True`` so an untrusted checkpoint cannot run code on load.
loaded_model = models.densenet121()
num_features = loaded_model.classifier.in_features
loaded_model.classifier = nn.Linear(in_features=num_features, out_features=5)
loaded_model.load_state_dict(
    torch.load(
        'derma_diseases_detection_best.pt',
        map_location=torch.device('cpu'),
    )
)
# Switch to evaluation mode so batch-norm layers use their running statistics.
loaded_model.eval()
# Define the image preprocessing function

# Transform the image using the same transformations as during training.
# Built once at import time instead of on every call.
# NOTE(review): normalization is disabled; confirm it matches the training
# pipeline before re-enabling the line below.
_PREPROCESS = transforms.Compose([
    transforms.Resize([224, 224]),
    transforms.ToTensor(),
    # transforms.Normalize(mean=[0.5523, 0.5288, 0.5106], std=[0.1012, 0.0820, 0.0509])
])


def preprocess_image(image):
    """Turn an image into a single-image batch tensor for the model.

    Args:
        image: An array accepted by ``PIL.Image.fromarray`` (e.g. an
            ``(H, W, 3)`` uint8 array) or a ``PIL.Image.Image``. A
            ``torch.Tensor`` is treated as already preprocessed and is only
            given a batch dimension if it lacks one.

    Returns:
        For image/array input, a float tensor of shape ``(1, 3, 224, 224)``
        produced by ``Resize`` + ``ToTensor``.
    """
    if isinstance(image, torch.Tensor):
        # Already preprocessed (e.g. output of a previous call): round-tripping
        # a tensor through PIL would fail, so only ensure the batch dimension.
        return image if image.dim() == 4 else image.unsqueeze(0)
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)
    # The network takes 3 channels; RGBA/palette PNGs and grayscale images
    # would otherwise yield 4- or 1-channel tensors and break the first conv.
    image = image.convert("RGB")
    return _PREPROCESS(image).unsqueeze(0)  # add batch dimension
# Define the prediction function

# Model output index -> label, in the class-index order used in training.
# NOTE(review): these are diabetic-retinopathy grades, yet the app title and
# checkpoint name refer to skin diseases — confirm the correct label list
# against the training dataset.
CLASS_LABELS = ['No DR', 'Mild', 'Moderate', 'Severe', 'Proliferative']


def predict_skin_disease(image):
    """Classify one image and return its predicted label.

    Args:
        image: Anything ``preprocess_image`` accepts (e.g. an ``(H, W, 3)``
            uint8 array), or a tensor that is already preprocessed
            (``(C, H, W)`` or ``(1, C, H, W)``), which is fed to the model
            as-is instead of being preprocessed a second time.

    Returns:
        The ``CLASS_LABELS`` entry for the highest-scoring model output.
        Assumes a batch of exactly one image (``.item()`` requires it).
    """
    if isinstance(image, torch.Tensor):
        # Already preprocessed: passing it through PIL again would fail.
        batch = image if image.dim() == 4 else image.unsqueeze(0)
    else:
        batch = preprocess_image(image)
    with torch.no_grad():  # inference only; skip building the autograd graph
        output = loaded_model(batch)
    _, predicted_class = torch.max(output, 1)
    return CLASS_LABELS[predicted_class.item()]
# Streamlit app
st.title("Skin Disease Detection")
uploaded_image = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
if uploaded_image is not None:
    # Display the uploaded image
    st.image(uploaded_image, caption="Uploaded Image.", use_column_width=True)
    # Decode the upload and force 3-channel RGB: PNG uploads may be RGBA or
    # palette images, which the 3-channel model input cannot take.
    image = Image.open(uploaded_image).convert("RGB")
    # Pass an (H, W, 3) uint8 array, the input preprocess_image is built for.
    # predict_skin_disease preprocesses internally, so it must not be given
    # the output of preprocess_image as well (that double pass crashed).
    prediction = predict_skin_disease(np.asarray(image))
    # Display the prediction
    st.success(f"Prediction: {prediction}")