# test / app.py
# Last commit by Oualidra: "Update app.py" (0765a71)
import gradio as gr
import torch
import torch.nn as nn
from torchvision import models, transforms
from PIL import Image
# Rebuild the network architecture so it matches the saved checkpoint:
# a torchvision DenseNet-121 whose classifier is swapped for a 5-way
# linear head.
loaded_model = models.densenet121()
num_features = loaded_model.classifier.in_features
loaded_model.classifier = nn.Linear(num_features, 5)

# Load the fine-tuned weights onto the CPU and switch to evaluation mode.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
_cpu_device = torch.device('cpu')
_state_dict = torch.load('derma_diseases_detection_best.pt', map_location=_cpu_device)
loaded_model.load_state_dict(_state_dict)
loaded_model.eval()
# Preprocessing pipeline applied to every input image. It is built once at
# import time rather than on every call, and is meant to mirror the
# training-time transforms.
# NOTE(review): the original kept
#   transforms.Normalize(mean=[0.5523, 0.5288, 0.5106], std=[0.1012, 0.0820, 0.0509])
# commented out. If the checkpoint was trained with that normalization,
# re-enable it here. TODO: confirm against the training script.
_PREPROCESS = transforms.Compose([
    transforms.Resize([224, 224]),
    transforms.ToTensor(),
])


def preprocess_image(image):
    """Convert an input image into a batched tensor for the model.

    Args:
        image: The image as a NumPy array (what Gradio's ``"image"`` input
            passes by default) or an already-decoded ``PIL.Image.Image``.

    Returns:
        A float tensor of shape ``(1, 3, 224, 224)``.
    """
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)
    # DenseNet's first convolution expects 3 channels. Grayscale, RGBA or
    # palette uploads would otherwise yield a 1- or 4-channel tensor and
    # crash the forward pass.
    image = image.convert("RGB")
    tensor = _PREPROCESS(image)
    return tensor.unsqueeze(0)  # add the batch dimension
# Output labels, indexed by the model's predicted class index (the model
# head has 5 outputs).
# NOTE(review): these look like diabetic-retinopathy grades, not skin-disease
# names, even though the checkpoint is 'derma_diseases_detection_best.pt'.
# Confirm the correct labels and their order against the training dataset.
CLASS_LABELS = ('No DR', 'Mild', 'Moderate', 'Severe', 'Proliferative')


def predict_skin_disease(image):
    """Classify an image and return the predicted label.

    Args:
        image: The image from the Gradio ``"image"`` input, or ``None`` when
            no image is present (the interface runs with ``live=True``, so it
            may be called after the input is cleared).

    Returns:
        The entry of ``CLASS_LABELS`` for the highest-scoring class, or an
        empty string when no image is provided.
    """
    # Guard against an empty input instead of crashing inside preprocessing.
    if image is None:
        return ""
    batch = preprocess_image(image)
    with torch.no_grad():
        scores = loaded_model(batch)
    _, predicted_index = torch.max(scores, 1)
    return CLASS_LABELS[predicted_index.item()]
# Wire the classifier into a live Gradio UI: an image input and a text
# output showing the predicted label.
iface = gr.Interface(
    fn=predict_skin_disease,
    inputs="image",
    outputs="text",
    live=True,
)

# Start serving the interface.
iface.launch()