Juliofc's picture
Update app.py
543f089 verified
import torch
import torch.nn as nn
from torchvision import models, transforms
import gradio as gr
from PIL import Image
# Class labels, indexed by the model's output position: logit i maps to
# class_names[i]. The order must match the label indices used when the
# checkpoint was trained (presumably the alphabetical folder order of the
# training dataset — confirm against the training script).
class_names = ["cordana", "healthy", "pestalotiopsis", "sigatoka"]
def load_model(weights_path="model_alexnet.pth"):
    """Build the AlexNet classifier and load the fine-tuned weights into it.

    The final classifier layer is replaced with a ``Linear`` layer whose
    output size is ``len(class_names)``. This must match the shape stored in
    the checkpoint, or ``load_state_dict`` will raise.

    Args:
        weights_path: Path to the saved ``state_dict``. Defaults to
            ``'model_alexnet.pth'``, the path previously hard-coded here.

    Returns:
        A ``(model, device)`` tuple. ``model`` is in eval mode and already
        moved to ``device``, which is ``cuda:0`` when CUDA is available and
        the CPU otherwise.
    """
    # Architecture only; the trained weights come from the checkpoint below.
    # NOTE(review): `pretrained=` is deprecated in torchvision>=0.13 in favour
    # of `weights=None`; kept here for compatibility with older versions.
    model = models.alexnet(pretrained=False)
    num_ftrs = model.classifier[6].in_features
    model.classifier[6] = nn.Linear(num_ftrs, len(class_names))

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # One call covers both targets: map_location places every tensor on
    # `device`, so a checkpoint saved on GPU loads on a CPU-only host and a
    # checkpoint saved on a different GPU index still loads.
    # NOTE(review): torch.load unpickles the file; only load trusted
    # checkpoints (consider weights_only=True on torch>=1.13).
    state_dict = torch.load(weights_path, map_location=device)
    model.load_state_dict(state_dict)
    model.eval()  # inference mode: disables dropout in AlexNet's classifier
    model.to(device)
    return model, device
# Build the network once at import time so every prediction request reuses
# the same loaded model and device.
model_alexnet, device = load_model()

# Preprocessing pipeline applied to every uploaded image:
#   1. resize so the shorter side is 256 px,
#   2. take the central 224x224 crop,
#   3. convert to a float tensor,
#   4. normalize each channel with the standard ImageNet mean/std
#      (presumably the same statistics used during fine-tuning — confirm).
transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
# Prediction function
def predict_image(image):
    """Classify a banana-leaf photo and return the predicted class name.

    Args:
        image: The value supplied by the Gradio ``"image"`` input. It is
            expected to be a NumPy array (presumably H x W or H x W x C with
            values in 0-255 — confirm against the Gradio version in use), or
            ``None`` when the user submits without uploading an image.

    Returns:
        One entry of ``class_names``, or ``None`` if no image was provided.
    """
    if image is None:
        # Nothing uploaded: return no label instead of raising
        # AttributeError on None.astype.
        return None
    # Let PIL infer the mode from the array shape, then force 3 channels.
    # This means grayscale (2-D) and RGBA uploads no longer break the
    # 3-channel Normalize step and the AlexNet input.
    pil_image = Image.fromarray(image.astype("uint8")).convert("RGB")
    batch = transform(pil_image).unsqueeze(0).to(device)  # add batch dimension
    with torch.no_grad():
        logits = model_alexnet(batch)
    # Index of the highest-scoring class for the single image in the batch.
    predicted_index = logits.argmax(dim=1).item()
    return class_names[predicted_index]
# Gradio front end: an image upload wired to predict_image, shown as a label.
# It offers one example image per class, loaded by relative path.
iface = gr.Interface(
    fn=predict_image,
    inputs="image",
    outputs="label",
    # "Healthy" is one of the four classes, so describe the task as three
    # diseases plus healthy rather than four diseases.
    description=(
        "This model is a fine-tuned version of AlexNet specifically designed to "
        "classify banana tree leaves as healthy or affected by one of three "
        "diseases. It can classify the leaves as Cordana, Healthy, "
        "Pestalotiopsis, or Sigatoka. Upload a photo of a banana leaf and the "
        "model will help you determine its health condition."
    ),
    examples=[
        "cordana.jpeg",
        "healthy.jpeg",
        "pestalotiopsis.jpeg",
        "sigatoka.jpeg",
    ],
)
iface.launch()