|
from pathlib import Path

import gradio as gr
import torch
import torch.nn as nn
from PIL import Image
from torchvision import models, transforms
|
|
|
|
|
|
|
# File name of the fine-tuned weights; loaded below as a state_dict for the
# network built here.
CHECKPOINT_NAME = 'derma_diseases_detection_best.pt'
# Number of outputs of the replacement classifier head.
NUM_CLASSES = 5

# Look for the checkpoint next to this script first, so the app works no
# matter which directory it is launched from; fall back to the original
# CWD-relative path when it is not found there.
_checkpoint_path = Path(__file__).resolve().with_name(CHECKPOINT_NAME)
if not _checkpoint_path.is_file():
    _checkpoint_path = Path(CHECKPOINT_NAME)

# DenseNet-121 built without pretrained weights (torchvision default); its
# final layer is swapped for a NUM_CLASSES-way linear head, and the actual
# weights come from the checkpoint.
loaded_model = models.densenet121()

num_features = loaded_model.classifier.in_features

loaded_model.classifier = nn.Linear(num_features, NUM_CLASSES)

# map_location maps all stored tensors to CPU, so a checkpoint saved on GPU
# still loads on a CPU-only host.
# NOTE(review): consider torch.load(..., weights_only=True) (PyTorch >= 1.13)
# so a tampered checkpoint file cannot execute code through pickle.
loaded_model.load_state_dict(torch.load(_checkpoint_path, map_location=torch.device('cpu')))

# Switch layers such as batch norm to inference behavior.
loaded_model.eval()
|
|
|
|
|
# Built once at import time instead of on every call. The pipeline holds no
# per-request state, so it can be shared across requests.
_PREPROCESS_TRANSFORM = transforms.Compose([
    transforms.Resize([224, 224]),
    transforms.ToTensor(),
    # NOTE(review): there is no transforms.Normalize step. If the checkpoint
    # was trained with mean/std normalization (e.g. ImageNet stats), inputs
    # here will be off-distribution -- confirm against the training code.
])


def preprocess_image(image):
    """Convert an input image into a single-image batch tensor for the model.

    Args:
        image: the image handed over by the Gradio ``"image"`` input --
            presumably a NumPy array accepted by ``PIL.Image.fromarray``
            (TODO confirm dtype/shape). A ``PIL.Image.Image`` is also accepted.

    Returns:
        A float tensor of shape ``(1, 3, 224, 224)`` scaled to [0, 1] by
        ``ToTensor``, ready to feed to the classifier.
    """
    pil_image = image if isinstance(image, Image.Image) else Image.fromarray(image)
    # The network expects 3 input channels. Grayscale ('L') or RGBA uploads
    # would otherwise yield 1- or 4-channel tensors and fail in the forward pass.
    pil_image = pil_image.convert('RGB')
    # Add a leading batch dimension of size 1.
    return _PREPROCESS_TRANSFORM(pil_image).unsqueeze(0)
|
|
|
|
|
# Output labels, indexed by the model's output position.
# NOTE(review): these are diabetic-retinopathy severity grades, which do not
# match the skin-disease purpose implied by the function and checkpoint
# names. Confirm the label set the checkpoint was actually trained on.
CLASS_LABELS = ('No DR', 'Mild', 'Moderate', 'Severe', 'Proliferative')


def predict_skin_disease(image):
    """Classify an image and return the predicted class label.

    Args:
        image: image supplied by the Gradio input component; may be ``None``
            when the input is empty.

    Returns:
        The entry of ``CLASS_LABELS`` for the highest-scoring model output,
        or an empty string when no image was provided.
    """
    # With live=True the function can be re-run after the user clears the
    # image, in which case no image is available; Image.fromarray(None)
    # would raise, so short-circuit instead.
    if image is None:
        return ''

    batch = preprocess_image(image)

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        scores = loaded_model(batch)

    # The batch holds one image, so argmax over the class dimension yields
    # a single index.
    predicted_index = scores.argmax(dim=1).item()
    return CLASS_LABELS[predicted_index]
|
|
|
|
|
# Web UI: an image input mapped to a text output showing the predicted
# label. live=True re-runs the prediction whenever the input changes.
iface = gr.Interface(fn=predict_skin_disease, inputs="image", outputs="text", live=True)


# Start the server only when executed as a script, so importing this module
# (e.g. from tests or another app) does not block inside launch().
if __name__ == "__main__":
    iface.launch()