Oualidra commited on
Commit
6f6b8aa
1 Parent(s): 55ef918

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +61 -0
app.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import torch
3
+ import torch.nn as nn
4
+ from torchvision import models, transforms
5
+ from PIL import Image
6
+
7
+ # Load the model
8
+ loaded_model = models.densenet121()
9
+
10
+ num_features = loaded_model.classifier.in_features
11
+ loaded_model.classifier = nn.Linear(num_features, 5)
12
+ loaded_model.load_state_dict(torch.load('derma_diseases_detection_best.pt', map_location=torch.device('cpu')))
13
+ loaded_model.eval()
14
+
15
+ # Define the image preprocessing function
16
+ def preprocess_image(image):
17
+ image = Image.fromarray(image)
18
+ # Transform the image using the same transformations as during training
19
+ transform = transforms.Compose([
20
+ transforms.Resize([224, 224]),
21
+ transforms.ToTensor(),
22
+ #transforms.Normalize(mean=[0.5523, 0.5288, 0.5106], std=[0.1012, 0.0820, 0.0509])
23
+ ])
24
+ image = transform(image)
25
+ image = image.unsqueeze(0) # Add batch dimension
26
+ return image
27
+
28
+ # Define the prediction function
29
+ def predict_skin_disease(image):
30
+ # Preprocess the input image
31
+ preprocessed_image = preprocess_image(image)
32
+
33
+ # Make prediction
34
+ with torch.no_grad():
35
+ output = loaded_model(preprocessed_image)
36
+ _, predicted_class = torch.max(output, 1)
37
+
38
+ # Map the predicted class index to the corresponding class label
39
+ class_label = ['No DR', 'Mild', 'Moderate', 'Severe', 'Proliferative']
40
+ class_label = class_label[predicted_class.item()]
41
+
42
+ return class_label
43
+
44
+ # Streamlit app
45
+ st.title("Skin Disease Detection")
46
+
47
+ uploaded_image = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
48
+
49
+ if uploaded_image is not None:
50
+ # Display the uploaded image
51
+ st.image(uploaded_image, caption="Uploaded Image.", use_column_width=True)
52
+
53
+ # Convert the image to the format expected by the model
54
+ image = Image.open(uploaded_image)
55
+ input_image = preprocess_image(image)
56
+
57
+ # Make prediction
58
+ prediction = predict_skin_disease(input_image)
59
+
60
+ # Display the prediction
61
+ st.success(f"Prediction: {prediction}")